yolov5 regress updates to yolov3
parent c94019f159
commit 110ead20e6

utils/utils.py | 154 lines changed
@@ -76,20 +76,6 @@ def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
     return image_weights


-def coco_class_weights():  # frequency of each class in coco train2014
-    n = [187437, 4955, 30920, 6033, 3838, 4332, 3160, 7051, 7677, 9167, 1316, 1372, 833, 6757, 7355, 3302, 3776, 4671,
-         6769, 5706, 3908, 903, 3686, 3596, 6200, 7920, 8779, 4505, 4272, 1862, 4698, 1962, 4403, 6659, 2402, 2689,
-         4012, 4175, 3411, 17048, 5637, 14553, 3923, 5539, 4289, 10084, 7018, 4314, 3099, 4638, 4939, 5543, 2038, 4004,
-         5053, 4578, 27292, 4113, 5931, 2905, 11174, 2873, 4036, 3415, 1517, 4122, 1980, 4464, 1190, 2302, 156, 3933,
-         1877, 17630, 4337, 4624, 1075, 3468, 135, 1380]
-    weights = 1 / torch.Tensor(n)
-    weights /= weights.sum()
-    # with open('data/coco.names', 'r') as f:
-    #     for k, v in zip(f.read().splitlines(), n):
-    #         print('%20s: %g' % (k, v))
-    return weights
-
-
 def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
     # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
     # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
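The deleted coco_class_weights() helper weighted each class by the inverse of its frequency in COCO train2014. A minimal sketch of the same idea, using only the first three class counts from the removed code:

```python
import torch

# Inverse-frequency class weighting, as in the removed helper:
# rarer classes receive larger weights, normalized to sum to 1.
n = torch.tensor([187437., 4955., 30920.])  # per-class instance counts (subset)
weights = 1 / n
weights /= weights.sum()
print(weights)  # the rarest class gets the largest weight
```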
@@ -355,7 +341,7 @@ def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#iss
 def compute_loss(p, targets, model):  # predictions, targets, model
     ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
     lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
-    tcls, tbox, indices, anchor_vec = build_targets(p, targets, model)
+    tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
     h = model.hyp  # hyperparameters
     red = 'mean'  # Loss reduction (sum or mean)

@@ -371,33 +357,33 @@ def compute_loss(p, targets, model):  # predictions, targets, model
     if g > 0:
         BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

-    # Compute losses
-    np, ng = 0, 0  # number grid points, targets
+    # per output
+    nt = 0  # targets
     for i, pi in enumerate(p):  # layer index, layer predictions
         b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
         tobj = torch.zeros_like(pi[..., 0])  # target obj
-        np += tobj.numel()

-        # Compute losses
-        nb = len(b)
-        if nb:  # number of targets
-            ng += nb
+        nb = b.shape[0]  # number of targets
+        if nb:
+            nt += nb
             ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets
-            # ps[:, 2:4] = torch.sigmoid(ps[:, 2:4])  # wh power loss (uncomment)

             # GIoU
-            pxy = torch.sigmoid(ps[:, 0:2])  # pxy = pxy * s - (s - 1) / 2,  s = 1.5  (scale_xy)
-            pwh = torch.exp(ps[:, 2:4]).clamp(max=1E3) * anchor_vec[i]
+            pxy = torch.sigmoid(ps[:, 0:2])
+            pwh = torch.exp(ps[:, 2:4]).clamp(max=1E3) * anchors[i]
             pbox = torch.cat((pxy, pwh), 1)  # predicted box
-            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation
+            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou(prediction, target)
             lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss

             # Obj
             tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype)  # giou ratio

             # Class
             if model.nc > 1:  # cls loss (only if multiple classes)
                 t = torch.full_like(ps[:, 5:], cn)  # targets
                 t[range(nb), tcls[i]] = cp
                 lcls += BCEcls(ps[:, 5:], t)  # BCE
                 # lcls += CE(ps[:, 5:], tcls[i])  # CE

             # Append targets to text file
             # with open('targets.txt', 'a') as file:
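The "# GIoU" lines above decode raw network outputs into grid-space boxes: sigmoid constrains the xy offset to its cell, and a clamped exponential scales the layer's anchors for wh. A minimal standalone sketch; shapes and values here are illustrative assumptions, not the repo's actual tensors:

```python
import torch

ps = torch.randn(8, 85)                 # 8 matched predictions: x, y, w, h, obj, 80 classes
anchors_i = torch.tensor([[3.6, 2.8]])  # anchor wh for one layer, in grid units (assumed)

pxy = torch.sigmoid(ps[:, 0:2])                         # xy offset within the cell, in (0, 1)
pwh = torch.exp(ps[:, 2:4]).clamp(max=1E3) * anchors_i  # wh as clamped anchor multiples
pbox = torch.cat((pxy, pwh), 1)                         # (8, 4) boxes in xywh grid units
print(pbox.shape)  # torch.Size([8, 4])
```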
@@ -410,26 +396,24 @@ def compute_loss(p, targets, model):  # predictions, targets, model
     lcls *= h['cls']
     if red == 'sum':
         bs = tobj.shape[0]  # batch size
-        lobj *= 3 / (6300 * bs) * 2  # 3 / np * 2
-        if ng:
-            lcls *= 3 / ng / model.nc
-            lbox *= 3 / ng
+        g = 3.0  # loss gain
+        lobj *= g / bs
+        if nt:
+            lcls *= g / nt / model.nc
+            lbox *= g / nt

     loss = lbox + lobj + lcls
     return loss, torch.cat((lbox, lobj, lcls, loss)).detach()


 def build_targets(p, targets, model):
-    # targets = [image, class, x, y, w, h]
-
+    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
     nt = targets.shape[0]
-    tcls, tbox, indices, av = [], [], [], []
+    tcls, tbox, indices, anch = [], [], [], []
     reject, use_all_anchors = True, True
     gain = torch.ones(6, device=targets.device)  # normalized to gridspace gain

-    # m = list(model.modules())[-1]
-    # for i in range(m.nl):
-    #    anchors = m.anchors[i]
-
     multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
     for i, j in enumerate(model.yolo_layers):
         # get number of grid points and anchor vec for this yolo layer
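Under 'sum' reduction, the old hard-coded 3 / (6300 * bs) * 2 factor becomes a single loss gain g spread over batch size (objectness) or target count (box and class). A sketch of the arithmetic with made-up magnitudes:

```python
import torch

lobj, lcls, lbox = torch.tensor([64.]), torch.tensor([12.]), torch.tensor([9.])
bs, nt, nc = 16, 48, 80  # batch size, target count, class count (assumed)

g = 3.0            # loss gain
lobj *= g / bs     # objectness averaged over images
if nt:
    lcls *= g / nt / nc  # classification averaged over targets and classes
    lbox *= g / nt       # box loss averaged over targets
print(lobj.item(), lcls.item(), lbox.item())
```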
@@ -455,16 +439,15 @@ def build_targets(p, targets, model):
                t, a = t[j], a[j]

         # Indices
-        b, c = t[:, :2].long().t()  # target image, class
-        gxy = t[:, 2:4]  # grid x, y
-        gwh = t[:, 4:6]  # grid w, h
-        gi, gj = gxy.long().t()  # grid x, y indices
+        b, c = t[:, :2].long().t()  # image, class
+        gxy = t[:, 2:4]  # grid xy
+        gwh = t[:, 4:6]  # grid wh
+        gi, gj = gxy.long().t()  # grid xy indices
         indices.append((b, a, gj, gi))

         # Box
-        gxy -= gxy.floor()  # xy
-        tbox.append(torch.cat((gxy, gwh), 1))  # xywh (grids)
-        av.append(anchors[a])  # anchor vec
+        tbox.append(torch.cat((gxy % 1., gwh), 1))  # xywh (grids)
+        anch.append(anchors[a])  # anchor vec

         # Class
         tcls.append(c)
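The gxy % 1. form keeps only the fractional offset inside each grid cell, matching the old in-place gxy -= gxy.floor() for the non-negative coordinates used here. A quick check:

```python
import torch

gxy = torch.tensor([[4.3, 7.9], [0.5, 12.0]])
assert torch.allclose(gxy % 1., gxy - gxy.floor())
print(gxy % 1.)  # fractional cell offsets, e.g. tensor([[0.3000, 0.9000], [0.5000, 0.0000]])
```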
@@ -473,7 +456,7 @@ def build_targets(p, targets, model):
                    'See https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data' % (
                        model.nc, model.nc - 1, c.max())

-    return tcls, tbox, indices, av
+    return tcls, tbox, indices, anch


 def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
@@ -486,17 +469,14 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
     # Box constraints
     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height

-    method = 'merge'
     nc = prediction[0].shape[1] - 5  # number of classes
     multi_label &= nc > 1  # multiple labels per box
-    output = [None] * len(prediction)
+    merge = True  # merge for best mAP
+    output = [None] * prediction.shape[0]
     for xi, x in enumerate(prediction):  # image index, image inference
-        # Apply conf constraint
-        x = x[x[:, 4] > conf_thres]
-
-        # Apply width-height constraint
-        x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)]
+        # Apply constraints
+        x = x[x[:, 4] > conf_thres]  # confidence
+        # x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)]  # width-height

         # If none remain process next image
         if not x.shape[0]:
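The surviving confidence constraint is plain boolean-mask row filtering. A tiny sketch; the [x, y, w, h, obj, cls...] row layout and shapes are assumed for illustration:

```python
import torch

x = torch.rand(100, 85)      # candidate detections for one image
conf_thres = 0.1
x = x[x[:, 4] > conf_thres]  # keep rows whose objectness exceeds the threshold
print(x.shape)               # roughly 90 of 100 rows survive for uniform scores
```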
@@ -521,8 +501,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
             x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]

         # Apply finite constraint
-        if not torch.isfinite(x).all():
-            x = x[torch.isfinite(x).all(1)]
+        # if not torch.isfinite(x).all():
+        #     x = x[torch.isfinite(x).all(1)]

         # If none remain process next image
         n = x.shape[0]  # number of boxes
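For reference, the now-disabled finite constraint dropped any row containing NaN or Inf:

```python
import torch

x = torch.tensor([[1., 2., 3.],
                  [float('nan'), 2., 3.],
                  [1., float('inf'), 3.]])
x = x[torch.isfinite(x).all(1)]  # keep only fully-finite rows
print(x)  # tensor([[1., 2., 3.]])
```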
@@ -530,28 +510,21 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
             continue

         # Sort by confidence
-        # if method == 'fast_batch':
-        #    x = x[x[:, 4].argsort(descending=True)]
+        # x = x[x[:, 4].argsort(descending=True)]

         # Batched NMS
         c = x[:, 5] * 0 if agnostic else x[:, 5]  # classes
         boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores
-        if method == 'merge':  # Merge NMS (boxes merged using weighted mean)
-            i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
-            if 1 < n < 3E3:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
-                try:
-                    # weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1)  # box weights
-                    # weights /= weights.sum(0)  # normalize
-                    # x[:, :4] = torch.mm(weights.T, x[:, :4])
-                    weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None]  # box weights
-                    x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
-                except:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
-                    pass
-        elif method == 'vision':
-            i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
-        elif method == 'fast':  # FastNMS from https://github.com/dbolya/yolact
-            iou = box_iou(boxes, boxes).triu_(diagonal=1)  # upper triangular iou matrix
-            i = iou.max(0)[0] < iou_thres
+        i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
+        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+                weights = iou * scores[None]  # box weights
+                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+                # i = i[iou.sum(1) > 1]  # require redundancy
+            except:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
+                print(x, i, x.shape, i.shape)
+                pass

         output[xi] = x[i]
     return output
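The rewritten block always runs plain torchvision NMS first, then (when merge is set) replaces each surviving box with a score-weighted average of every box it overlaps. A rough standalone sketch of that merge step, using torchvision's box_iou in place of the repo's own helper:

```python
import torch
import torchvision

def merge_nms_sketch(boxes, scores, iou_thres=0.6):
    # boxes: (n, 4) in x1y1x2y2, scores: (n,) — toy stand-ins, not the repo's code
    i = torchvision.ops.nms(boxes, scores, iou_thres)           # survivors of plain NMS
    iou = torchvision.ops.box_iou(boxes[i], boxes) > iou_thres  # (len(i), n) overlap mask
    weights = iou * scores[None]                                # score-weight overlapping boxes
    merged = torch.mm(weights.float(), boxes) / weights.sum(1, keepdim=True)
    return merged, i

boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
merged, keep = merge_nms_sketch(boxes, scores)
print(keep, merged)  # two kept boxes; the first is nudged toward its overlapping neighbor
```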
@@ -621,13 +594,6 @@ def coco_only_people(path='../coco/labels/train2017/'):  # from utils.utils import *; coco_only_people()
             print(labels.shape[0], file)


-def select_best_evolve(path='evolve*.txt'):  # from utils.utils import *; select_best_evolve()
-    # Find best evolved mutation
-    for file in sorted(glob.glob(path)):
-        x = np.loadtxt(file, dtype=np.float32, ndmin=2)
-        print(file, x[fitness(x).argmax()])
-
-
 def crop_images_random(path='../images/', scale=0.50):  # from utils.utils import *; crop_images_random()
     # crops images into random squares up to scale fraction
     # WARNING: overwrites images!
@@ -708,17 +674,12 @@ def kmean_anchors(path='./data/coco64.txt', n=9, img_size=(320, 1024), thr=0.20,
     wh *= np.random.uniform(img_size[0], img_size[1], size=(wh.shape[0], 1))  # normalized to pixels (multi-scale)
     wh = wh[(wh > 2.0).all(1)]  # remove below threshold boxes (< 2 pixels wh)

-    # Darknet yolov3.cfg anchors
-    use_darknet = False
-    if use_darknet and n == 9:
-        k = np.array([[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]])
-    else:
-        # Kmeans calculation
-        from scipy.cluster.vq import kmeans
-        print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
-        s = wh.std(0)  # sigmas for whitening
-        k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
-        k *= s
+    # Kmeans calculation
+    from scipy.cluster.vq import kmeans
+    print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
+    s = wh.std(0)  # sigmas for whitening
+    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
+    k *= s
     wh = torch.Tensor(wh)
     k = print_results(k)

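The retained path clusters box wh in whitened space (divided by per-dimension std, so width and height contribute comparably) and scales the centroids back afterwards. A standalone sketch on synthetic data:

```python
import numpy as np
from scipy.cluster.vq import kmeans

np.random.seed(0)
wh = np.abs(np.random.randn(1000, 2)) * 50 + 10  # fake box widths/heights (pixels)
s = wh.std(0)                         # sigmas for whitening
k, dist = kmeans(wh / s, 9, iter=30)  # 9 centroids in whitened space
k *= s                                # un-whiten back to pixel units
print(np.round(k, 1))
```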
@@ -741,7 +702,7 @@ def kmean_anchors(path='./data/coco64.txt', n=9, img_size=(320, 1024), thr=0.20,
     for _ in tqdm(range(gen), desc='Evolving anchors'):
         v = np.ones(sh)
         while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
-            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)  # 98.6, 61.6
+            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
         kg = (k.copy() * v).clip(min=2.0)
         fg = fitness(kg)
         if fg > f:
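This loop is a greedy genetic refinement: multiply the anchors by clipped random noise and keep the mutation only if fitness improves. A sketch of one step; fitness() here is a hypothetical stand-in, while the repo scores anchors against dataset wh statistics:

```python
import numpy as np

npr = np.random
k = np.array([[10., 13.], [30., 61.], [116., 90.]])  # current anchors (illustrative)
fitness = lambda k: -np.abs(k - 40).mean()           # toy objective (assumption)

f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1  # best fitness, shape, mutation prob, sigma
v = np.ones(sh)
while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
    v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
kg = (k.copy() * v).clip(min=2.0)  # mutated anchors, floored at 2 pixels
if fitness(kg) > f:                # greedy: keep only improving mutations
    k = kg
print(k)
```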
@@ -815,17 +776,13 @@ def fitness(x):
 def output_to_target(output, width, height):
     """
     Convert a YOLO model output to target format
-
     [batch_id, class_id, x, y, w, h, conf]
-
     """
-
     if isinstance(output, torch.Tensor):
         output = output.cpu().numpy()

     targets = []
     for i, o in enumerate(output):
-
         if o is not None:
             for pred in o:
                 box = pred[:4]
@@ -951,6 +908,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
         cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

     if fname is not None:
+        mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
         cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))

     return mosaic
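The added line halves the mosaic with INTER_AREA (the interpolation suited to shrinking) before it is written to disk. A sketch with assumed grid and tile dimensions:

```python
import numpy as np
import cv2

ns, w, h = 4, 416, 416  # subplot grid size and tile dims (assumed)
mosaic = np.zeros((ns * h, ns * w, 3), dtype=np.uint8)  # stand-in mosaic image
mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
print(mosaic.shape)  # (832, 832, 3)
```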
@@ -993,7 +951,7 @@ def plot_evolution_results(hyp):  # from utils.utils import *; plot_evolution_results(hyp)
     # Plot hyperparameter evolution results in evolve.txt
     x = np.loadtxt('evolve.txt', ndmin=2)
     f = fitness(x)
-    weights = (f - f.min()) ** 2  # for weighted results
+    # weights = (f - f.min()) ** 2  # for weighted results
     fig = plt.figure(figsize=(12, 10))
     matplotlib.rc('font', **{'size': 8})
     for i, (k, v) in enumerate(hyp.items()):
@@ -1055,8 +1013,8 @@ def plot_results(start=0, stop=0, bucket='', id=()):  # from utils.utils import
             # y /= y[0]  # normalize
             ax[i].plot(x, y, marker='.', label=Path(f).stem, linewidth=2, markersize=8)
             ax[i].set_title(s[i])
-            if i in [5, 6, 7]:  # share train and val loss y axes
-                ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
+            # if i in [5, 6, 7]:  # share train and val loss y axes
+            #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
         except:
             print('Warning: Plotting error for %s, skipping file' % f)