cleanup
This commit is contained in:
parent
f5a2682a81
commit
ac4c90c817
|
@ -440,21 +440,21 @@ def build_targets(p, targets, model):
|
||||||
|
|
||||||
# m = list(model.modules())[-1]
|
# m = list(model.modules())[-1]
|
||||||
# for i in range(m.nl):
|
# for i in range(m.nl):
|
||||||
# anchor_vec = m.anchor_vec[i]
|
# anchors = m.anchors[i]
|
||||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||||
for i, j in enumerate(model.yolo_layers):
|
for i, j in enumerate(model.yolo_layers):
|
||||||
# get number of grid points and anchor vec for this yolo layer
|
# get number of grid points and anchor vec for this yolo layer
|
||||||
anchor_vec = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
|
anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
|
||||||
|
|
||||||
# iou of targets-anchors
|
# iou of targets-anchors
|
||||||
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
||||||
t, a = targets * gain, []
|
t, a = targets * gain, []
|
||||||
gwh = t[:, 4:6]
|
gwh = t[:, 4:6]
|
||||||
if nt:
|
if nt:
|
||||||
iou = wh_iou(anchor_vec, gwh) # iou(3,n) = wh_iou(anchor_vec(3,2), gwh(n,2))
|
iou = wh_iou(anchors, gwh) # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
|
||||||
|
|
||||||
if use_all_anchors:
|
if use_all_anchors:
|
||||||
na = anchor_vec.shape[0] # number of anchors
|
na = anchors.shape[0] # number of anchors
|
||||||
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1)
|
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1)
|
||||||
t = t.repeat(na, 1)
|
t = t.repeat(na, 1)
|
||||||
else: # use best anchor only
|
else: # use best anchor only
|
||||||
|
@ -475,7 +475,7 @@ def build_targets(p, targets, model):
|
||||||
# Box
|
# Box
|
||||||
gxy -= gxy.floor() # xy
|
gxy -= gxy.floor() # xy
|
||||||
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
|
tbox.append(torch.cat((gxy, gwh), 1)) # xywh (grids)
|
||||||
av.append(anchor_vec[a]) # anchor vec
|
av.append(anchors[a]) # anchor vec
|
||||||
|
|
||||||
# Class
|
# Class
|
||||||
tcls.append(c)
|
tcls.append(c)
|
||||||
|
|
Loading…
Reference in New Issue