This commit is contained in:
Glenn Jocher 2020-04-14 13:08:00 -07:00
parent f5a2682a81
commit ac4c90c817
1 changed files with 7 additions and 7 deletions

View File

@ -440,21 +440,21 @@ def build_targets(p, targets, model):
# m = list(model.modules())[-1] # m = list(model.modules())[-1]
# for i in range(m.nl): # for i in range(m.nl):
# anchor_vec = m.anchor_vec[i] # anchors = m.anchors[i]
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
for i, j in enumerate(model.yolo_layers): for i, j in enumerate(model.yolo_layers):
# get number of grid points and anchor vec for this yolo layer # get number of grid points and anchor vec for this yolo layer
anchor_vec = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
# iou of targets-anchors # iou of targets-anchors
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
t, a = targets * gain, [] t, a = targets * gain, []
gwh = t[:, 4:6] gwh = t[:, 4:6]
if nt: if nt:
iou = wh_iou(anchor_vec, gwh) # iou(3,n) = wh_iou(anchor_vec(3,2), gwh(n,2)) iou = wh_iou(anchors, gwh) # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
if use_all_anchors: if use_all_anchors:
na = anchor_vec.shape[0] # number of anchors na = anchors.shape[0] # number of anchors
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1) a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1)
t = t.repeat(na, 1) t = t.repeat(na, 1)
else: # use best anchor only else: # use best anchor only
@ -475,7 +475,7 @@ def build_targets(p, targets, model):
# Box # Box
gxy -= gxy.floor() # xy gxy -= gxy.floor() # xy
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids) tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
av.append(anchor_vec[a]) # anchor vec av.append(anchors[a]) # anchor vec
# Class # Class
tcls.append(c) tcls.append(c)