tight_layout=True
This commit is contained in:
parent
16ea613628
commit
23f85a68b8
|
@ -5,6 +5,7 @@ import random
|
|||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
from sys import platform
|
||||
|
||||
|
@ -370,8 +371,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
|
||||
|
||||
# GIoU
|
||||
pxy = torch.sigmoid(ps[:, 0:2])
|
||||
pwh = torch.exp(ps[:, 2:4]).clamp(max=1E3) * anchors[i]
|
||||
pxy = ps[:, :2].sigmoid()
|
||||
pwh = ps[:, 2:4].exp().clamp(max=1E3) * anchors[i]
|
||||
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target)
|
||||
lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss
|
||||
|
@ -416,7 +417,6 @@ def build_targets(p, targets, model):
|
|||
style = None
|
||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||
for i, j in enumerate(model.yolo_layers):
|
||||
# get number of grid points and anchor vec for this yolo layer
|
||||
anchors = model.module.module_list[j].anchor_vec if multi_gpu else model.module_list[j].anchor_vec
|
||||
gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
||||
na = anchors.shape[0] # number of anchors
|
||||
|
@ -573,7 +573,7 @@ def strip_optimizer(f='weights/best.pt'): # from utils.utils import *; strip_op
|
|||
torch.save(x, f)
|
||||
|
||||
|
||||
def create_backbone(f='weights/last.pt'): # from utils.utils import *; create_backbone()
|
||||
def create_backbone(f='weights/best.pt'): # from utils.utils import *; create_backbone()
|
||||
# create a backbone from a *.pt file
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
x['optimizer'] = None
|
||||
|
@ -816,12 +816,12 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
|
|||
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
|
||||
color = color or [random.randint(0, 255) for _ in range(3)]
|
||||
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
|
||||
cv2.rectangle(img, c1, c2, color, thickness=tl)
|
||||
cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
|
||||
if label:
|
||||
tf = max(tl - 1, 1) # font thickness
|
||||
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
|
||||
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
|
||||
cv2.rectangle(img, c1, c2, color, -1) # filled
|
||||
cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
|
||||
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
||||
|
||||
|
||||
|
@ -928,22 +928,34 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
|
|||
return mosaic
|
||||
|
||||
|
||||
def plot_lr_scheduler(optimizer, scheduler, epochs=300):
|
||||
# Plot LR simulating training for full epochs
|
||||
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
|
||||
y = []
|
||||
for _ in range(epochs):
|
||||
scheduler.step()
|
||||
y.append(optimizer.param_groups[0]['lr'])
|
||||
plt.plot(y, '.-', label='LR')
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('LR')
|
||||
plt.tight_layout()
|
||||
plt.savefig('LR.png', dpi=200)
|
||||
|
||||
|
||||
def plot_test_txt(): # from utils.utils import *; plot_test()
|
||||
# Plot test.txt histograms
|
||||
x = np.loadtxt('test.txt', dtype=np.float32)
|
||||
box = xyxy2xywh(x[:, :4])
|
||||
cx, cy = box[:, 0], box[:, 1]
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
|
||||
fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
|
||||
ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
|
||||
ax.set_aspect('equal')
|
||||
fig.tight_layout()
|
||||
plt.savefig('hist2d.png', dpi=300)
|
||||
|
||||
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
|
||||
fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
|
||||
ax[0].hist(cx, bins=600)
|
||||
ax[1].hist(cy, bins=600)
|
||||
fig.tight_layout()
|
||||
plt.savefig('hist1d.png', dpi=200)
|
||||
|
||||
|
||||
|
@ -951,22 +963,45 @@ def plot_targets_txt(): # from utils.utils import *; plot_targets_txt()
|
|||
# Plot targets.txt histograms
|
||||
x = np.loadtxt('targets.txt', dtype=np.float32).T
|
||||
s = ['x targets', 'y targets', 'width targets', 'height targets']
|
||||
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
|
||||
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
|
||||
ax = ax.ravel()
|
||||
for i in range(4):
|
||||
ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
|
||||
ax[i].legend()
|
||||
ax[i].set_title(s[i])
|
||||
fig.tight_layout()
|
||||
plt.savefig('targets.jpg', dpi=200)
|
||||
|
||||
|
||||
def plot_labels(labels):
|
||||
# plot dataset labels
|
||||
c, b = labels[:, 0], labels[:, 1:].transpose() # classees, boxes
|
||||
|
||||
def hist2d(x, y, n=100):
|
||||
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
|
||||
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
|
||||
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
|
||||
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
|
||||
return hist[xidx, yidx]
|
||||
|
||||
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
|
||||
ax = ax.ravel()
|
||||
ax[0].hist(c, bins=int(c.max() + 1))
|
||||
ax[0].set_xlabel('classes')
|
||||
ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
|
||||
ax[1].set_xlabel('x')
|
||||
ax[1].set_ylabel('y')
|
||||
ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
|
||||
ax[2].set_xlabel('width')
|
||||
ax[2].set_ylabel('height')
|
||||
plt.savefig('labels.png', dpi=200)
|
||||
|
||||
|
||||
def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_results(hyp)
|
||||
# Plot hyperparameter evolution results in evolve.txt
|
||||
x = np.loadtxt('evolve.txt', ndmin=2)
|
||||
f = fitness(x)
|
||||
# weights = (f - f.min()) ** 2 # for weighted results
|
||||
fig = plt.figure(figsize=(12, 10))
|
||||
fig = plt.figure(figsize=(12, 10), tight_layout=True)
|
||||
matplotlib.rc('font', **{'size': 8})
|
||||
for i, (k, v) in enumerate(hyp.items()):
|
||||
y = x[:, i + 7]
|
||||
|
@ -977,7 +1012,6 @@ def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_re
|
|||
plt.plot(y, f, '.')
|
||||
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
|
||||
print('%15s: %.3g' % (k, mu))
|
||||
fig.tight_layout()
|
||||
plt.savefig('evolve.png', dpi=200)
|
||||
|
||||
|
||||
|
@ -989,7 +1023,7 @@ def plot_results_overlay(start=0, stop=0): # from utils.utils import *; plot_re
|
|||
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
|
||||
n = results.shape[1] # number of rows
|
||||
x = range(start, min(stop, n) if stop else n)
|
||||
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5))
|
||||
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
|
||||
ax = ax.ravel()
|
||||
for i in range(5):
|
||||
for j in [i, i + 5]:
|
||||
|
@ -1000,13 +1034,12 @@ def plot_results_overlay(start=0, stop=0): # from utils.utils import *; plot_re
|
|||
ax[i].set_title(t[i])
|
||||
ax[i].legend()
|
||||
ax[i].set_ylabel(f) if i == 0 else None # add filename
|
||||
fig.tight_layout()
|
||||
fig.savefig(f.replace('.txt', '.png'), dpi=200)
|
||||
|
||||
|
||||
def plot_results(start=0, stop=0, bucket='', id=()): # from utils.utils import *; plot_results()
|
||||
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov3#training
|
||||
fig, ax = plt.subplots(2, 5, figsize=(12, 6))
|
||||
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
||||
ax = ax.ravel()
|
||||
s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
|
||||
'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'F1']
|
||||
|
@ -1032,6 +1065,5 @@ def plot_results(start=0, stop=0, bucket='', id=()): # from utils.utils import
|
|||
except:
|
||||
print('Warning: Plotting error for %s, skipping file' % f)
|
||||
|
||||
fig.tight_layout()
|
||||
ax[1].legend()
|
||||
fig.savefig('results.png', dpi=200)
|
||||
|
|
Loading…
Reference in New Issue