code cleanup

This commit is contained in:
Glenn Jocher 2020-04-14 04:15:05 -07:00
parent 25725c8569
commit 198a5a591d
1 changed files with 3 additions and 3 deletions

View File

@ -435,12 +435,12 @@ def build_targets(p, targets, model):
nt = targets.shape[0] nt = targets.shape[0]
tcls, tbox, indices, av = [], [], [], [] tcls, tbox, indices, av = [], [], [], []
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
reject, use_all_anchors = True, True reject, use_all_anchors = True, True
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
for j, i in enumerate(model.yolo_layers): multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
for i, j in enumerate(model.yolo_layers):
# get number of grid points and anchor vec for this yolo layer # get number of grid points and anchor vec for this yolo layer
anchor_vec = model.module.module_list[i].anchor_vec if multi_gpu else model.module_list[i].anchor_vec anchor_vec = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
# iou of targets-anchors # iou of targets-anchors
gain[2:] = torch.tensor(p[j].shape)[[2, 3, 2, 3]] # xyxy gain gain[2:] = torch.tensor(p[j].shape)[[2, 3, 2, 3]] # xyxy gain