bug fix
This commit is contained in:
parent
835b0da68a
commit
25725c8569
|
@ -438,12 +438,12 @@ def build_targets(p, targets, model):
|
||||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||||
reject, use_all_anchors = True, True
|
reject, use_all_anchors = True, True
|
||||||
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
|
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
|
||||||
for i in model.yolo_layers:
|
for j, i in enumerate(model.yolo_layers):
|
||||||
# get number of grid points and anchor vec for this yolo layer
|
# get number of grid points and anchor vec for this yolo layer
|
||||||
anchor_vec = model.module.module_list[i].anchor_vec if multi_gpu else model.module_list[i].anchor_vec
|
anchor_vec = model.module.module_list[i].anchor_vec if multi_gpu else model.module_list[i].anchor_vec
|
||||||
|
|
||||||
# iou of targets-anchors
|
# iou of targets-anchors
|
||||||
gain[2:] = torch.tensor(p[0].shape)[[2, 3, 2, 3]] # xyxy gain
|
gain[2:] = torch.tensor(p[j].shape)[[2, 3, 2, 3]] # xyxy gain
|
||||||
t, a = targets * gain, []
|
t, a = targets * gain, []
|
||||||
gwh = t[:, 4:6]
|
gwh = t[:, 4:6]
|
||||||
if nt:
|
if nt:
|
||||||
|
@ -452,7 +452,7 @@ def build_targets(p, targets, model):
|
||||||
if use_all_anchors:
|
if use_all_anchors:
|
||||||
na = anchor_vec.shape[0] # number of anchors
|
na = anchor_vec.shape[0] # number of anchors
|
||||||
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1)
|
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1)
|
||||||
t = targets.repeat(na, 1)
|
t = t.repeat(na, 1)
|
||||||
else: # use best anchor only
|
else: # use best anchor only
|
||||||
iou, a = iou.max(0) # best iou and anchor
|
iou, a = iou.max(0) # best iou and anchor
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue