bug fix
This commit is contained in:
parent
1681249588
commit
029e137bc2
|
@ -437,13 +437,17 @@ def build_targets(p, targets, model):
|
|||
tcls, tbox, indices, av = [], [], [], []
|
||||
reject, use_all_anchors = True, True
|
||||
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
|
||||
|
||||
# m = list(model.modules())[-1]
|
||||
# for i in range(m.nl):
|
||||
# anchor_vec = m.anchor_vec[i]
|
||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||
for i, j in enumerate(model.yolo_layers):
|
||||
# get number of grid points and anchor vec for this yolo layer
|
||||
anchor_vec = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
|
||||
|
||||
# iou of targets-anchors
|
||||
gain[2:] = torch.tensor(p[i].shape)[[2, 3, 2, 3]] # xyxy gain
|
||||
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
||||
t, a = targets * gain, []
|
||||
gwh = t[:, 4:6]
|
||||
if nt:
|
||||
|
|
Loading…
Reference in New Issue