yolov5 regress updates to yolov3 - build_targets()

This commit is contained in:
Glenn Jocher 2020-05-17 15:10:31 -07:00
parent 110ead20e6
commit c8f4ee6c46
1 changed files with 35 additions and 27 deletions

View File

@ -408,49 +408,57 @@ def compute_loss(p, targets, model): # predictions, targets, model
def build_targets(p, targets, model): def build_targets(p, targets, model):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h) # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
nt = targets.shape[0] nt = targets.shape[0]
tcls, tbox, indices, anch = [], [], [], [] tcls, tbox, indices, anch = [], [], [], []
reject, use_all_anchors = True, True
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
style = None
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
for i, j in enumerate(model.yolo_layers): for i, j in enumerate(model.yolo_layers):
# get number of grid points and anchor vec for this yolo layer # get number of grid points and anchor vec for this yolo layer
anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
# iou of targets-anchors
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
t, a = targets * gain, []
gwh = t[:, 4:6]
if nt:
iou = wh_iou(anchors, gwh) # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
if use_all_anchors:
na = anchors.shape[0] # number of anchors na = anchors.shape[0] # number of anchors
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1) at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
t = t.repeat(na, 1)
else: # use best anchor only
iou, a = iou.max(0) # best iou and anchor
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R) # Match targets to anchors
if reject: a, t, offsets = [], targets * gain, 0
j = iou.view(-1) > model.hyp['iou_t'] # iou threshold hyperparameter if nt:
t, a = t[j], a[j] # r = t[None, :, 4:6] / anchors[:, None] # wh ratio
# j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare
j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
a, t = at[j], t.repeat(na, 1, 1)[j] # filter
# Indices # overlaps
b, c = t[:, :2].long().t() # image, class gxy = t[:, 2:4] # grid xy
z = torch.zeros_like(gxy)
if style == 'rect2':
g = 0.2 # offset
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
elif style == 'rect4':
g = 0.5 # offset
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g
# Define
b, c = t[:, :2].long().T # image, class
gxy = t[:, 2:4] # grid xy gxy = t[:, 2:4] # grid xy
gwh = t[:, 4:6] # grid wh gwh = t[:, 4:6] # grid wh
gi, gj = gxy.long().t() # grid xy indices gij = (gxy - offsets).long()
indices.append((b, a, gj, gi)) gi, gj = gij.T # grid xy indices
# Box # Append
tbox.append(torch.cat((gxy % 1., gwh), 1)) # xywh (grids) indices.append((b, a, gj, gi)) # image, anchor, grid indices
anch.append(anchors[a]) # anchor vec tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
anch.append(anchors[a]) # anchors
# Class tcls.append(c) # class
tcls.append(c)
if c.shape[0]: # if any targets if c.shape[0]: # if any targets
assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \ assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \
'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % ( 'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % (