tight_layout=True

This commit is contained in:
Glenn Jocher 2020-05-24 10:51:35 -07:00
parent 16ea613628
commit 23f85a68b8
1 changed files with 50 additions and 18 deletions

View File

@ -5,6 +5,7 @@ import random
import shutil import shutil
import subprocess import subprocess
import time import time
from copy import copy
from pathlib import Path from pathlib import Path
from sys import platform from sys import platform
@ -370,8 +371,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
# GIoU # GIoU
pxy = torch.sigmoid(ps[:, 0:2]) pxy = ps[:, :2].sigmoid()
pwh = torch.exp(ps[:, 2:4]).clamp(max=1E3) * anchors[i] pwh = ps[:, 2:4].exp().clamp(max=1E3) * anchors[i]
pbox = torch.cat((pxy, pwh), 1) # predicted box pbox = torch.cat((pxy, pwh), 1) # predicted box
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target) giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target)
lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss
@ -416,7 +417,6 @@ def build_targets(p, targets, model):
style = None style = None
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
for i, j in enumerate(model.yolo_layers): for i, j in enumerate(model.yolo_layers):
# get number of grid points and anchor vec for this yolo layer
anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
na = anchors.shape[0] # number of anchors na = anchors.shape[0] # number of anchors
@ -573,7 +573,7 @@ def strip_optimizer(f='weights/best.pt'): # from utils.utils import *; strip_op
torch.save(x, f) torch.save(x, f)
def create_backbone(f='weights/last.pt'): # from utils.utils import *; create_backbone() def create_backbone(f='weights/best.pt'): # from utils.utils import *; create_backbone()
# create a backbone from a *.pt file # create a backbone from a *.pt file
x = torch.load(f, map_location=torch.device('cpu')) x = torch.load(f, map_location=torch.device('cpu'))
x['optimizer'] = None x['optimizer'] = None
@ -816,12 +816,12 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
color = color or [random.randint(0, 255) for _ in range(3)] color = color or [random.randint(0, 255) for _ in range(3)]
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
cv2.rectangle(img, c1, c2, color, thickness=tl) cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
if label: if label:
tf = max(tl - 1, 1) # font thickness tf = max(tl - 1, 1) # font thickness
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
cv2.rectangle(img, c1, c2, color, -1) # filled cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
@ -928,22 +928,34 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
return mosaic return mosaic
def plot_lr_scheduler(optimizer, scheduler, epochs=300):
# Plot LR simulating training for full epochs
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
y = []
for _ in range(epochs):
scheduler.step()
y.append(optimizer.param_groups[0]['lr'])
plt.plot(y, '.-', label='LR')
plt.xlabel('epoch')
plt.ylabel('LR')
plt.tight_layout()
plt.savefig('LR.png', dpi=200)
def plot_test_txt(): # from utils.utils import *; plot_test() def plot_test_txt(): # from utils.utils import *; plot_test()
# Plot test.txt histograms # Plot test.txt histograms
x = np.loadtxt('test.txt', dtype=np.float32) x = np.loadtxt('test.txt', dtype=np.float32)
box = xyxy2xywh(x[:, :4]) box = xyxy2xywh(x[:, :4])
cx, cy = box[:, 0], box[:, 1] cx, cy = box[:, 0], box[:, 1]
fig, ax = plt.subplots(1, 1, figsize=(6, 6)) fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0) ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
ax.set_aspect('equal') ax.set_aspect('equal')
fig.tight_layout()
plt.savefig('hist2d.png', dpi=300) plt.savefig('hist2d.png', dpi=300)
fig, ax = plt.subplots(1, 2, figsize=(12, 6)) fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
ax[0].hist(cx, bins=600) ax[0].hist(cx, bins=600)
ax[1].hist(cy, bins=600) ax[1].hist(cy, bins=600)
fig.tight_layout()
plt.savefig('hist1d.png', dpi=200) plt.savefig('hist1d.png', dpi=200)
@ -951,22 +963,45 @@ def plot_targets_txt(): # from utils.utils import *; plot_targets_txt()
# Plot targets.txt histograms # Plot targets.txt histograms
x = np.loadtxt('targets.txt', dtype=np.float32).T x = np.loadtxt('targets.txt', dtype=np.float32).T
s = ['x targets', 'y targets', 'width targets', 'height targets'] s = ['x targets', 'y targets', 'width targets', 'height targets']
fig, ax = plt.subplots(2, 2, figsize=(8, 8)) fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
ax = ax.ravel() ax = ax.ravel()
for i in range(4): for i in range(4):
ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std())) ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
ax[i].legend() ax[i].legend()
ax[i].set_title(s[i]) ax[i].set_title(s[i])
fig.tight_layout()
plt.savefig('targets.jpg', dpi=200) plt.savefig('targets.jpg', dpi=200)
def plot_labels(labels):
# plot dataset labels
c, b = labels[:, 0], labels[:, 1:].transpose() # classees, boxes
def hist2d(x, y, n=100):
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
return hist[xidx, yidx]
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
ax = ax.ravel()
ax[0].hist(c, bins=int(c.max() + 1))
ax[0].set_xlabel('classes')
ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
ax[1].set_xlabel('x')
ax[1].set_ylabel('y')
ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
ax[2].set_xlabel('width')
ax[2].set_ylabel('height')
plt.savefig('labels.png', dpi=200)
def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_results(hyp) def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_results(hyp)
# Plot hyperparameter evolution results in evolve.txt # Plot hyperparameter evolution results in evolve.txt
x = np.loadtxt('evolve.txt', ndmin=2) x = np.loadtxt('evolve.txt', ndmin=2)
f = fitness(x) f = fitness(x)
# weights = (f - f.min()) ** 2 # for weighted results # weights = (f - f.min()) ** 2 # for weighted results
fig = plt.figure(figsize=(12, 10)) fig = plt.figure(figsize=(12, 10), tight_layout=True)
matplotlib.rc('font', **{'size': 8}) matplotlib.rc('font', **{'size': 8})
for i, (k, v) in enumerate(hyp.items()): for i, (k, v) in enumerate(hyp.items()):
y = x[:, i + 7] y = x[:, i + 7]
@ -977,7 +1012,6 @@ def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_re
plt.plot(y, f, '.') plt.plot(y, f, '.')
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
print('%15s: %.3g' % (k, mu)) print('%15s: %.3g' % (k, mu))
fig.tight_layout()
plt.savefig('evolve.png', dpi=200) plt.savefig('evolve.png', dpi=200)
@ -989,7 +1023,7 @@ def plot_results_overlay(start=0, stop=0): # from utils.utils import *; plot_re
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
n = results.shape[1] # number of rows n = results.shape[1] # number of rows
x = range(start, min(stop, n) if stop else n) x = range(start, min(stop, n) if stop else n)
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5)) fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
ax = ax.ravel() ax = ax.ravel()
for i in range(5): for i in range(5):
for j in [i, i + 5]: for j in [i, i + 5]:
@ -1000,13 +1034,12 @@ def plot_results_overlay(start=0, stop=0): # from utils.utils import *; plot_re
ax[i].set_title(t[i]) ax[i].set_title(t[i])
ax[i].legend() ax[i].legend()
ax[i].set_ylabel(f) if i == 0 else None # add filename ax[i].set_ylabel(f) if i == 0 else None # add filename
fig.tight_layout()
fig.savefig(f.replace('.txt', '.png'), dpi=200) fig.savefig(f.replace('.txt', '.png'), dpi=200)
def plot_results(start=0, stop=0, bucket='', id=()): # from utils.utils import *; plot_results() def plot_results(start=0, stop=0, bucket='', id=()): # from utils.utils import *; plot_results()
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov3#training # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov3#training
fig, ax = plt.subplots(2, 5, figsize=(12, 6)) fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
ax = ax.ravel() ax = ax.ravel()
s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall', s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'F1'] 'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'F1']
@ -1032,6 +1065,5 @@ def plot_results(start=0, stop=0, bucket='', id=()): # from utils.utils import
except: except:
print('Warning: Plotting error for %s, skipping file' % f) print('Warning: Plotting error for %s, skipping file' % f)
fig.tight_layout()
ax[1].legend() ax[1].legend()
fig.savefig('results.png', dpi=200) fig.savefig('results.png', dpi=200)