yolov5 regress updates to yolov3 - build_targets()
This commit is contained in:
parent
110ead20e6
commit
c8f4ee6c46
|
@ -408,49 +408,57 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
|
||||
def build_targets(p, targets, model):
|
||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||
|
||||
nt = targets.shape[0]
|
||||
tcls, tbox, indices, anch = [], [], [], []
|
||||
reject, use_all_anchors = True, True
|
||||
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
|
||||
off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
|
||||
|
||||
|
||||
style = None
|
||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||
for i, j in enumerate(model.yolo_layers):
|
||||
# get number of grid points and anchor vec for this yolo layer
|
||||
anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
|
||||
|
||||
# iou of targets-anchors
|
||||
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
||||
t, a = targets * gain, []
|
||||
gwh = t[:, 4:6]
|
||||
na = anchors.shape[0] # number of anchors
|
||||
at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
|
||||
|
||||
# Match targets to anchors
|
||||
a, t, offsets = [], targets * gain, 0
|
||||
if nt:
|
||||
iou = wh_iou(anchors, gwh) # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
|
||||
# r = t[None, :, 4:6] / anchors[:, None] # wh ratio
|
||||
# j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare
|
||||
j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
|
||||
a, t = at[j], t.repeat(na, 1, 1)[j] # filter
|
||||
|
||||
if use_all_anchors:
|
||||
na = anchors.shape[0] # number of anchors
|
||||
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1)
|
||||
t = t.repeat(na, 1)
|
||||
else: # use best anchor only
|
||||
iou, a = iou.max(0) # best iou and anchor
|
||||
# overlaps
|
||||
gxy = t[:, 2:4] # grid xy
|
||||
z = torch.zeros_like(gxy)
|
||||
if style == 'rect2':
|
||||
g = 0.2 # offset
|
||||
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
||||
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
|
||||
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
|
||||
|
||||
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R)
|
||||
if reject:
|
||||
j = iou.view(-1) > model.hyp['iou_t'] # iou threshold hyperparameter
|
||||
t, a = t[j], a[j]
|
||||
elif style == 'rect4':
|
||||
g = 0.5 # offset
|
||||
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
||||
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
|
||||
a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
|
||||
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g
|
||||
|
||||
# Indices
|
||||
b, c = t[:, :2].long().t() # image, class
|
||||
# Define
|
||||
b, c = t[:, :2].long().T # image, class
|
||||
gxy = t[:, 2:4] # grid xy
|
||||
gwh = t[:, 4:6] # grid wh
|
||||
gi, gj = gxy.long().t() # grid xy indices
|
||||
indices.append((b, a, gj, gi))
|
||||
gij = (gxy - offsets).long()
|
||||
gi, gj = gij.T # grid xy indices
|
||||
|
||||
# Box
|
||||
tbox.append(torch.cat((gxy % 1., gwh), 1)) # xywh (grids)
|
||||
anch.append(anchors[a]) # anchor vec
|
||||
|
||||
# Class
|
||||
tcls.append(c)
|
||||
# Append
|
||||
indices.append((b, a, gj, gi)) # image, anchor, grid indices
|
||||
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
|
||||
anch.append(anchors[a]) # anchors
|
||||
tcls.append(c) # class
|
||||
if c.shape[0]: # if any targets
|
||||
assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \
|
||||
'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % (
|
||||
|
|
Loading…
Reference in New Issue