bug fix
This commit is contained in:
		
							parent
							
								
									835b0da68a
								
							
						
					
					
						commit
						25725c8569
					
				|  | @ -438,12 +438,12 @@ def build_targets(p, targets, model): | ||||||
|     multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) |     multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) | ||||||
|     reject, use_all_anchors = True, True |     reject, use_all_anchors = True, True | ||||||
|     gain = torch.ones(6, device=targets.device)  # normalized to gridspace gain |     gain = torch.ones(6, device=targets.device)  # normalized to gridspace gain | ||||||
|     for i in model.yolo_layers: |     for j, i in enumerate(model.yolo_layers): | ||||||
|         # get number of grid points and anchor vec for this yolo layer |         # get number of grid points and anchor vec for this yolo layer | ||||||
|         anchor_vec = model.module.module_list[i].anchor_vec if multi_gpu else model.module_list[i].anchor_vec |         anchor_vec = model.module.module_list[i].anchor_vec if multi_gpu else model.module_list[i].anchor_vec | ||||||
| 
 | 
 | ||||||
|         # iou of targets-anchors |         # iou of targets-anchors | ||||||
|         gain[2:] = torch.tensor(p[0].shape)[[2, 3, 2, 3]]  # xyxy gain |         gain[2:] = torch.tensor(p[j].shape)[[2, 3, 2, 3]]  # xyxy gain | ||||||
|         t, a = targets * gain, [] |         t, a = targets * gain, [] | ||||||
|         gwh = t[:, 4:6] |         gwh = t[:, 4:6] | ||||||
|         if nt: |         if nt: | ||||||
|  | @ -452,7 +452,7 @@ def build_targets(p, targets, model): | ||||||
|             if use_all_anchors: |             if use_all_anchors: | ||||||
|                 na = anchor_vec.shape[0]  # number of anchors |                 na = anchor_vec.shape[0]  # number of anchors | ||||||
|                 a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1) |                 a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1) | ||||||
|                 t = targets.repeat(na, 1) |                 t = t.repeat(na, 1) | ||||||
|             else:  # use best anchor only |             else:  # use best anchor only | ||||||
|                 iou, a = iou.max(0)  # best iou and anchor |                 iou, a = iou.max(0)  # best iou and anchor | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue