yolov5 regress updates to yolov3 - build_targets()
This commit is contained in:
parent
110ead20e6
commit
c8f4ee6c46
|
@ -408,49 +408,57 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
|
|
||||||
def build_targets(p, targets, model):
|
def build_targets(p, targets, model):
|
||||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||||
|
|
||||||
nt = targets.shape[0]
|
nt = targets.shape[0]
|
||||||
tcls, tbox, indices, anch = [], [], [], []
|
tcls, tbox, indices, anch = [], [], [], []
|
||||||
reject, use_all_anchors = True, True
|
|
||||||
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
|
gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
|
||||||
|
off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
|
||||||
|
|
||||||
|
style = None
|
||||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||||
for i, j in enumerate(model.yolo_layers):
|
for i, j in enumerate(model.yolo_layers):
|
||||||
# get number of grid points and anchor vec for this yolo layer
|
# get number of grid points and anchor vec for this yolo layer
|
||||||
anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
|
anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
|
||||||
|
|
||||||
# iou of targets-anchors
|
|
||||||
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
||||||
t, a = targets * gain, []
|
|
||||||
gwh = t[:, 4:6]
|
|
||||||
if nt:
|
|
||||||
iou = wh_iou(anchors, gwh) # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
|
|
||||||
|
|
||||||
if use_all_anchors:
|
|
||||||
na = anchors.shape[0] # number of anchors
|
na = anchors.shape[0] # number of anchors
|
||||||
a = torch.arange(na).view(-1, 1).repeat(1, nt).view(-1)
|
at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
|
||||||
t = t.repeat(na, 1)
|
|
||||||
else: # use best anchor only
|
|
||||||
iou, a = iou.max(0) # best iou and anchor
|
|
||||||
|
|
||||||
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R)
|
# Match targets to anchors
|
||||||
if reject:
|
a, t, offsets = [], targets * gain, 0
|
||||||
j = iou.view(-1) > model.hyp['iou_t'] # iou threshold hyperparameter
|
if nt:
|
||||||
t, a = t[j], a[j]
|
# r = t[None, :, 4:6] / anchors[:, None] # wh ratio
|
||||||
|
# j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare
|
||||||
|
j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
|
||||||
|
a, t = at[j], t.repeat(na, 1, 1)[j] # filter
|
||||||
|
|
||||||
# Indices
|
# overlaps
|
||||||
b, c = t[:, :2].long().t() # image, class
|
gxy = t[:, 2:4] # grid xy
|
||||||
|
z = torch.zeros_like(gxy)
|
||||||
|
if style == 'rect2':
|
||||||
|
g = 0.2 # offset
|
||||||
|
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
||||||
|
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
|
||||||
|
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
|
||||||
|
|
||||||
|
elif style == 'rect4':
|
||||||
|
g = 0.5 # offset
|
||||||
|
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
||||||
|
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
|
||||||
|
a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
|
||||||
|
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g
|
||||||
|
|
||||||
|
# Define
|
||||||
|
b, c = t[:, :2].long().T # image, class
|
||||||
gxy = t[:, 2:4] # grid xy
|
gxy = t[:, 2:4] # grid xy
|
||||||
gwh = t[:, 4:6] # grid wh
|
gwh = t[:, 4:6] # grid wh
|
||||||
gi, gj = gxy.long().t() # grid xy indices
|
gij = (gxy - offsets).long()
|
||||||
indices.append((b, a, gj, gi))
|
gi, gj = gij.T # grid xy indices
|
||||||
|
|
||||||
# Box
|
# Append
|
||||||
tbox.append(torch.cat((gxy % 1., gwh), 1)) # xywh (grids)
|
indices.append((b, a, gj, gi)) # image, anchor, grid indices
|
||||||
anch.append(anchors[a]) # anchor vec
|
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
|
||||||
|
anch.append(anchors[a]) # anchors
|
||||||
# Class
|
tcls.append(c) # class
|
||||||
tcls.append(c)
|
|
||||||
if c.shape[0]: # if any targets
|
if c.shape[0]: # if any targets
|
||||||
assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \
|
assert c.max() < model.nc, 'Model accepts %g classes labeled from 0-%g, however you labelled a class %g. ' \
|
||||||
'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % (
|
'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % (
|
||||||
|
|
Loading…
Reference in New Issue