updates
This commit is contained in:
parent
e1c407dab1
commit
1613d1c396
|
@ -337,21 +337,25 @@ def build_targets(model, targets):
|
|||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||
|
||||
nt = len(targets)
|
||||
txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], []
|
||||
txy, twh, tcls, tbox, indices, av = [], [], [], [], [], []
|
||||
for i in model.yolo_layers:
|
||||
layer = model.module.module_list[i] if multi_gpu else model.module_list[i]
|
||||
# get number of grid points and anchor vec for this yolo layer
|
||||
if multi_gpu:
|
||||
ng, anchor_vec = model.module.module_list[i].ng, model.module.module_list[i].anchor_vec
|
||||
else:
|
||||
ng, anchor_vec = model.module_list[i].ng, model.module_list[i].anchor_vec
|
||||
|
||||
# iou of targets-anchors
|
||||
t, a = targets, []
|
||||
gwh = t[:, 4:6] * layer.ng
|
||||
gwh = t[:, 4:6] * ng
|
||||
if nt:
|
||||
iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0)
|
||||
iou = torch.stack([wh_iou(x, gwh) for x in anchor_vec], 0)
|
||||
|
||||
use_best_anchor = False
|
||||
if use_best_anchor:
|
||||
iou, a = iou.max(0) # best iou and anchor
|
||||
else: # use all anchors
|
||||
na = len(layer.anchor_vec) # number of anchors
|
||||
na = len(anchor_vec) # number of anchors
|
||||
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
|
||||
t = targets.repeat([na, 1])
|
||||
gwh = gwh.repeat([na, 1])
|
||||
|
@ -365,7 +369,7 @@ def build_targets(model, targets):
|
|||
|
||||
# Indices
|
||||
b, c = t[:, :2].long().t() # target image, class
|
||||
gxy = t[:, 2:4] * layer.ng # grid x, y
|
||||
gxy = t[:, 2:4] * ng # grid x, y
|
||||
gi, gj = gxy.long().t() # grid x, y indices
|
||||
indices.append((b, a, gj, gi))
|
||||
|
||||
|
@ -375,18 +379,18 @@ def build_targets(model, targets):
|
|||
|
||||
# GIoU
|
||||
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
|
||||
anchor_vec.append(layer.anchor_vec[a])
|
||||
av.append(anchor_vec[a]) # anchor vec
|
||||
|
||||
# Width and height
|
||||
twh.append(torch.log(gwh / layer.anchor_vec[a])) # wh yolo method
|
||||
# twh.append((gwh / layer.anchor_vec[a]) ** (1 / 3) / 2) # wh power method
|
||||
twh.append(torch.log(gwh / anchor_vec[a])) # wh yolo method
|
||||
# twh.append((gwh / anchor_vec[a]) ** (1 / 3) / 2) # wh power method
|
||||
|
||||
# Class
|
||||
tcls.append(c)
|
||||
if c.shape[0]: # if any targets
|
||||
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
||||
assert c.max() <= model.nc, 'Target classes exceed model classes'
|
||||
|
||||
return txy, twh, tcls, tbox, indices, anchor_vec
|
||||
return txy, twh, tcls, tbox, indices, av
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||
|
|
Loading…
Reference in New Issue