updates
This commit is contained in:
parent
e1c407dab1
commit
1613d1c396
|
@ -337,21 +337,25 @@ def build_targets(model, targets):
|
||||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||||
|
|
||||||
nt = len(targets)
|
nt = len(targets)
|
||||||
txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], []
|
txy, twh, tcls, tbox, indices, av = [], [], [], [], [], []
|
||||||
for i in model.yolo_layers:
|
for i in model.yolo_layers:
|
||||||
layer = model.module.module_list[i] if multi_gpu else model.module_list[i]
|
# get number of grid points and anchor vec for this yolo layer
|
||||||
|
if multi_gpu:
|
||||||
|
ng, anchor_vec = model.module.module_list[i].ng, model.module.module_list[i].anchor_vec
|
||||||
|
else:
|
||||||
|
ng, anchor_vec = model.module_list[i].ng, model.module_list[i].anchor_vec
|
||||||
|
|
||||||
# iou of targets-anchors
|
# iou of targets-anchors
|
||||||
t, a = targets, []
|
t, a = targets, []
|
||||||
gwh = t[:, 4:6] * layer.ng
|
gwh = t[:, 4:6] * ng
|
||||||
if nt:
|
if nt:
|
||||||
iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0)
|
iou = torch.stack([wh_iou(x, gwh) for x in anchor_vec], 0)
|
||||||
|
|
||||||
use_best_anchor = False
|
use_best_anchor = False
|
||||||
if use_best_anchor:
|
if use_best_anchor:
|
||||||
iou, a = iou.max(0) # best iou and anchor
|
iou, a = iou.max(0) # best iou and anchor
|
||||||
else: # use all anchors
|
else: # use all anchors
|
||||||
na = len(layer.anchor_vec) # number of anchors
|
na = len(anchor_vec) # number of anchors
|
||||||
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
|
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
|
||||||
t = targets.repeat([na, 1])
|
t = targets.repeat([na, 1])
|
||||||
gwh = gwh.repeat([na, 1])
|
gwh = gwh.repeat([na, 1])
|
||||||
|
@ -365,7 +369,7 @@ def build_targets(model, targets):
|
||||||
|
|
||||||
# Indices
|
# Indices
|
||||||
b, c = t[:, :2].long().t() # target image, class
|
b, c = t[:, :2].long().t() # target image, class
|
||||||
gxy = t[:, 2:4] * layer.ng # grid x, y
|
gxy = t[:, 2:4] * ng # grid x, y
|
||||||
gi, gj = gxy.long().t() # grid x, y indices
|
gi, gj = gxy.long().t() # grid x, y indices
|
||||||
indices.append((b, a, gj, gi))
|
indices.append((b, a, gj, gi))
|
||||||
|
|
||||||
|
@ -375,18 +379,18 @@ def build_targets(model, targets):
|
||||||
|
|
||||||
# GIoU
|
# GIoU
|
||||||
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
|
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
|
||||||
anchor_vec.append(layer.anchor_vec[a])
|
av.append(anchor_vec[a]) # anchor vec
|
||||||
|
|
||||||
# Width and height
|
# Width and height
|
||||||
twh.append(torch.log(gwh / layer.anchor_vec[a])) # wh yolo method
|
twh.append(torch.log(gwh / anchor_vec[a])) # wh yolo method
|
||||||
# twh.append((gwh / layer.anchor_vec[a]) ** (1 / 3) / 2) # wh power method
|
# twh.append((gwh / anchor_vec[a]) ** (1 / 3) / 2) # wh power method
|
||||||
|
|
||||||
# Class
|
# Class
|
||||||
tcls.append(c)
|
tcls.append(c)
|
||||||
if c.shape[0]: # if any targets
|
if c.shape[0]: # if any targets
|
||||||
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
assert c.max() <= model.nc, 'Target classes exceed model classes'
|
||||||
|
|
||||||
return txy, twh, tcls, tbox, indices, anchor_vec
|
return txy, twh, tcls, tbox, indices, av
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
|
|
Loading…
Reference in New Issue