This commit is contained in:
Glenn Jocher 2019-08-05 17:41:25 +02:00
parent e1c407dab1
commit 1613d1c396
1 changed files with 15 additions and 11 deletions

View File

@ -337,21 +337,25 @@ def build_targets(model, targets):
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
nt = len(targets)
txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], []
txy, twh, tcls, tbox, indices, av = [], [], [], [], [], []
for i in model.yolo_layers:
layer = model.module.module_list[i] if multi_gpu else model.module_list[i]
# get number of grid points and anchor vec for this yolo layer
if multi_gpu:
ng, anchor_vec = model.module.module_list[i].ng, model.module.module_list[i].anchor_vec
else:
ng, anchor_vec = model.module_list[i].ng, model.module_list[i].anchor_vec
# iou of targets-anchors
t, a = targets, []
gwh = t[:, 4:6] * layer.ng
gwh = t[:, 4:6] * ng
if nt:
iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0)
iou = torch.stack([wh_iou(x, gwh) for x in anchor_vec], 0)
use_best_anchor = False
if use_best_anchor:
iou, a = iou.max(0) # best iou and anchor
else: # use all anchors
na = len(layer.anchor_vec) # number of anchors
na = len(anchor_vec) # number of anchors
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
t = targets.repeat([na, 1])
gwh = gwh.repeat([na, 1])
@ -365,7 +369,7 @@ def build_targets(model, targets):
# Indices
b, c = t[:, :2].long().t() # target image, class
gxy = t[:, 2:4] * layer.ng # grid x, y
gxy = t[:, 2:4] * ng # grid x, y
gi, gj = gxy.long().t() # grid x, y indices
indices.append((b, a, gj, gi))
@ -375,18 +379,18 @@ def build_targets(model, targets):
# GIoU
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
anchor_vec.append(layer.anchor_vec[a])
av.append(anchor_vec[a]) # anchor vec
# Width and height
twh.append(torch.log(gwh / layer.anchor_vec[a])) # wh yolo method
# twh.append((gwh / layer.anchor_vec[a]) ** (1 / 3) / 2) # wh power method
twh.append(torch.log(gwh / anchor_vec[a])) # wh yolo method
# twh.append((gwh / anchor_vec[a]) ** (1 / 3) / 2) # wh power method
# Class
tcls.append(c)
if c.shape[0]: # if any targets
assert c.max() <= layer.nc, 'Target classes exceed model classes'
assert c.max() <= model.nc, 'Target classes exceed model classes'
return txy, twh, tcls, tbox, indices, anchor_vec
return txy, twh, tcls, tbox, indices, av
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):