faster and more informative training plots (#1114)

* faster and more informative training plots

* Update utils.py

Looks good. Needs pep8 linting, I'll do that in PyCharm later once PR is in.

* Update test.py

* Update train.py

f for the tb descriptor lets us plot several batches, i.e. to allow us to change L292 to 'if ni < 3' for 3 examples.

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Josh Veitch-Michaelis 2020-04-30 21:37:04 +01:00 committed by GitHub
parent f1d73a29e5
commit fb1b5e09b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 144 additions and 28 deletions

12
test.py
View File

@ -82,11 +82,6 @@ def test(cfg,
nb, _, height, width = imgs.shape # batch size, channels, height, width nb, _, height, width = imgs.shape # batch size, channels, height, width
whwh = torch.Tensor([width, height, width, height]).to(device) whwh = torch.Tensor([width, height, width, height]).to(device)
# Plot images with bounding boxes
f = 'test_batch%g.jpg' % batch_i # filename
if batch_i < 1 and not os.path.exists(f):
plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)
# Disable gradients # Disable gradients
with torch.no_grad(): with torch.no_grad():
# Run model # Run model
@ -167,6 +162,13 @@ def test(cfg,
# Append statistics (correct, conf, pcls, tcls) # Append statistics (correct, conf, pcls, tcls)
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
# Plot images
if batch_i < 1:
f = 'test_batch%g_gt.jpg' % batch_i # filename
plot_images(images=imgs, targets=targets, paths=paths, names=names, fname=f) # ground truth
f = 'test_batch%g_pred.jpg' % batch_i # filename
plot_images(images=imgs, targets=output_to_target(output, width, height), paths=paths, names=names, fname=f) # predictions
# Compute statistics # Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
if len(stats): if len(stats):

View File

@ -292,9 +292,9 @@ def train():
# Plot # Plot
if ni < 1: if ni < 1:
f = 'train_batch%g.jpg' % i # filename f = 'train_batch%g.jpg' % i # filename
plot_images(imgs=imgs, targets=targets, paths=paths, fname=f) res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer: if tb_writer:
tb_writer.add_image(f, cv2.imread(f)[:, :, ::-1], dataformats='HWC') tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard # tb_writer.add_graph(model, imgs) # add model to tensorboard
# end batch ------------------------------------------------------------------------------------------------ # end batch ------------------------------------------------------------------------------------------------

View File

@ -829,6 +829,35 @@ def fitness(x):
return (x[:, :4] * w).sum(1) return (x[:, :4] * w).sum(1)
def output_to_target(output, width, height):
"""
Convert a YOLO model output to target format
[batch_id, class_id, x, y, w, h, conf]
"""
if isinstance(output, torch.Tensor):
output = output.cpu().numpy()
targets = []
for i, o in enumerate(output):
if o is not None:
for pred in o:
box = pred[:4]
w = (box[2]-box[0])/width
h = (box[3]-box[1])/height
x = box[0]/width + w/2
y = box[1]/height + h/2
conf = pred[4]
cls = int(pred[5])
targets.append([i, cls, x, y, w, h, conf])
return np.array(targets)
# Plotting functions --------------------------------------------------------------------------------------------------- # Plotting functions ---------------------------------------------------------------------------------------------------
def plot_one_box(x, img, color=None, label=None, line_thickness=None): def plot_one_box(x, img, color=None, label=None, line_thickness=None):
# Plots one bounding box on image img # Plots one bounding box on image img
@ -864,30 +893,115 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
fig.savefig('comparison.png', dpi=200) fig.savefig('comparison.png', dpi=200)
def plot_images(imgs, targets, paths=None, fname='images.png'): def plot_images(images, targets, paths=None, fname='images.jpg', names=None, class_labels=True, confidence_labels=True, max_size=640, max_subplots=16):
# Plots training images overlaid with targets
imgs = imgs.cpu().numpy()
targets = targets.cpu().numpy()
# targets = targets[targets[:, 1] == 21] # plot only one class
fig = plt.figure(figsize=(10, 10)) if isinstance(images, torch.Tensor):
bs, _, h, w = imgs.shape # batch size, _, height, width images = images.cpu().numpy()
bs = min(bs, 16) # limit plot to 16 images
ns = np.ceil(bs ** 0.5) # number of subplots if isinstance(targets, torch.Tensor):
targets = targets.cpu().numpy()
# un-normalise
if np.max(images[0]) <= 1:
images *= 255
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
# Check if we should resize
should_resize = False
if w > max_size or h > max_size:
scale_factor = max_size/max(h, w)
h = math.ceil(scale_factor*h)
w = math.ceil(scale_factor*w)
should_resize=True
# Empty array for output
mosaic_width = int(ns*w)
mosaic_height = int(ns*h)
mosaic = 255*np.ones((mosaic_height, mosaic_width, 3), dtype=np.uint8)
# Fix class - colour map
prop_cycle = plt.rcParams['axes.prop_cycle']
# https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
hex2rgb = lambda h : tuple(int(h[1+i:1+i+2], 16) for i in (0, 2, 4))
color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']]
for i in range(bs): for i, image in enumerate(images):
boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).T
boxes[[0, 2]] *= w # e.g. if the last batch has fewer images than we expect
boxes[[1, 3]] *= h if i == max_subplots:
plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0)) break
plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
plt.axis('off') block_x = int(w * (i // ns))
block_y = int(h * (i % ns))
image = image.transpose(1,2,0)
if should_resize:
image = cv2.resize(image, (w, h))
mosaic[block_y:block_y+h, block_x:block_x+w,:] = image
if targets is not None:
image_targets = targets[targets[:, 0] == i]
boxes = xywh2xyxy(image_targets[:,2:6]).T
classes = image_targets[:,1].astype('int')
# Check if we have object confidences (gt vs pred)
confidences = None
if image_targets.shape[1] > 6:
confidences = image_targets[:,6]
boxes[[0, 2]] *= w
boxes[[0, 2]] += block_x
boxes[[1, 3]] *= h
boxes[[1, 3]] += block_y
for j, box in enumerate(boxes.T):
color = color_lut[int(classes[j]) % len(color_lut)]
box = box.astype(int)
cv2.rectangle(mosaic, (box[0], box[1]), (box[2], box[3]), color, thickness=2)
# Draw class label
if class_labels and max_size > 250:
label = str(classes[j]) if names is None else names[classes[j]]
if confidences is not None and confidence_labels:
label += " {:1.2f}".format(confidences[j])
font_scale = 0.4/10 * min(20, h * 0.05)
font_thickness = 2 if max(w, h) > 320 else 1
label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
cv2.rectangle(mosaic, (box[0], box[1]), (box[0]+label_size[0], box[1]-label_size[1]), color, thickness=-1)
cv2.putText(mosaic, label, (box[0], box[1]), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=font_thickness, color=(255,255,255))
# Draw image filename labels
if paths is not None: if paths is not None:
s = Path(paths[i]).name # Trim to 40 chars
plt.title(s[:min(len(s), 40)], fontdict={'size': 8}) # limit to 40 characters label = os.path.basename(paths[i])[:40]
fig.tight_layout()
fig.savefig(fname, dpi=200) # Empirical calculation to fit label
plt.close() # 0.4 is at most (13, 10) px per char at thickness = 1
# Fit label to 20px high, or shrink if it would be too big
max_font_scale = (w/len(label))*(0.4/8)
font_scale = min(0.4 * 20/8.5, max_font_scale)
font_thickness = 1
label_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, font_scale, font_thickness)
cv2.rectangle(mosaic, (block_x+5, block_y+label_size[1]+baseline+5), (block_x+label_size[0]+5, block_y), 0, thickness=-1)
cv2.putText(mosaic, label, (block_x+5, block_y+label_size[1]+5), cv2.FONT_HERSHEY_DUPLEX, font_scale, (255,255,255), font_thickness)
# Image border
cv2.rectangle(mosaic, (block_x, block_y), (block_x+w, block_y+h), (255,255,255), thickness=3)
if fname is not None:
cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))
return mosaic
def plot_test_txt(): # from utils.utils import *; plot_test() def plot_test_txt(): # from utils.utils import *; plot_test()