yolov5 regress updates to yolov3 - build_targets()
This commit is contained in:
		
							parent
							
								
									110ead20e6
								
							
						
					
					
						commit
						c8f4ee6c46
					
				|  | @ -408,49 +408,57 @@ def compute_loss(p, targets, model):  # predictions, targets, model | ||||||
| 
 | 
 | ||||||
| def build_targets(p, targets, model): | def build_targets(p, targets, model): | ||||||
|     # Build targets for compute_loss(), input targets(image,class,x,y,w,h) |     # Build targets for compute_loss(), input targets(image,class,x,y,w,h) | ||||||
|  | 
 | ||||||
|     nt = targets.shape[0] |     nt = targets.shape[0] | ||||||
|     tcls, tbox, indices, anch = [], [], [], [] |     tcls, tbox, indices, anch = [], [], [], [] | ||||||
|     reject, use_all_anchors = True, True |  | ||||||
|     gain = torch.ones(6, device=targets.device)  # normalized to gridspace gain |     gain = torch.ones(6, device=targets.device)  # normalized to gridspace gain | ||||||
|  |     off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float()  # overlap offsets | ||||||
| 
 | 
 | ||||||
| 
 |     style = None | ||||||
|     multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) |     multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) | ||||||
|     for i, j in enumerate(model.yolo_layers): |     for i, j in enumerate(model.yolo_layers): | ||||||
|         # get number of grid points and anchor vec for this yolo layer |         # get number of grid points and anchor vec for this yolo layer | ||||||
|         anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec |         anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec | ||||||
| 
 |  | ||||||
|         # iou of targets-anchors |  | ||||||
|         gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain |         gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain | ||||||
|         t, a = targets * gain, [] |  | ||||||
|         gwh = t[:, 4:6] |  | ||||||
|         if nt: |  | ||||||
|             iou = wh_iou(anchors, gwh)  # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2)) |  | ||||||
| 
 |  | ||||||
|             if use_all_anchors: |  | ||||||
|         na = anchors.shape[0]  # number of anchors |         na = anchors.shape[0]  # number of anchors | ||||||
|                 a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1) |         at = torch.arange(na).view(na, 1).repeat(1, nt)  # anchor tensor, same as .repeat_interleave(nt) | ||||||
|                 t = t.repeat(na, 1) |  | ||||||
|             else:  # use best anchor only |  | ||||||
|                 iou, a = iou.max(0)  # best iou and anchor |  | ||||||
| 
 | 
 | ||||||
|             # reject anchors below iou_thres (OPTIONAL, increases P, lowers R) |         # Match targets to anchors | ||||||
|             if reject: |         a, t, offsets = [], targets * gain, 0 | ||||||
|                 j = iou.view(-1) > model.hyp['iou_t']  # iou threshold hyperparameter |         if nt: | ||||||
|                 t, a = t[j], a[j] |             # r = t[None, :, 4:6] / anchors[:, None]  # wh ratio | ||||||
|  |             # j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare | ||||||
|  |             j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2)) | ||||||
|  |             a, t = at[j], t.repeat(na, 1, 1)[j]  # filter | ||||||
| 
 | 
 | ||||||
|         # Indices |             # overlaps | ||||||
|         b, c = t[:, :2].long().t()  # image, class |             gxy = t[:, 2:4]  # grid xy | ||||||
|  |             z = torch.zeros_like(gxy) | ||||||
|  |             if style == 'rect2': | ||||||
|  |                 g = 0.2  # offset | ||||||
|  |                 j, k = ((gxy % 1. < g) & (gxy > 1.)).T | ||||||
|  |                 a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0) | ||||||
|  |                 offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g | ||||||
|  | 
 | ||||||
|  |             elif style == 'rect4': | ||||||
|  |                 g = 0.5  # offset | ||||||
|  |                 j, k = ((gxy % 1. < g) & (gxy > 1.)).T | ||||||
|  |                 l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T | ||||||
|  |                 a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0) | ||||||
|  |                 offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g | ||||||
|  | 
 | ||||||
|  |         # Define | ||||||
|  |         b, c = t[:, :2].long().T  # image, class | ||||||
|         gxy = t[:, 2:4]  # grid xy |         gxy = t[:, 2:4]  # grid xy | ||||||
|         gwh = t[:, 4:6]  # grid wh |         gwh = t[:, 4:6]  # grid wh | ||||||
|         gi, gj = gxy.long().t()  # grid xy indices |         gij = (gxy - offsets).long() | ||||||
|         indices.append((b, a, gj, gi)) |         gi, gj = gij.T  # grid xy indices | ||||||
| 
 | 
 | ||||||
|         # Box |         # Append | ||||||
|         tbox.append(torch.cat((gxy % 1., gwh), 1))  # xywh (grids) |         indices.append((b, a, gj, gi))  # image, anchor, grid indices | ||||||
|         anch.append(anchors[a])  # anchor vec |         tbox.append(torch.cat((gxy - gij, gwh), 1))  # box | ||||||
| 
 |         anch.append(anchors[a])  # anchors | ||||||
|         # Class |         tcls.append(c)  # class | ||||||
|         tcls.append(c) |  | ||||||
|         if c.shape[0]:  # if any targets |         if c.shape[0]:  # if any targets | ||||||
|             assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \ |             assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \ | ||||||
|                                        'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % ( |                                        'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % ( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue