From fb1b5e09b2452edbc57d629183527eeadb7851be Mon Sep 17 00:00:00 2001 From: Josh Veitch-Michaelis Date: Thu, 30 Apr 2020 21:37:04 +0100 Subject: [PATCH] faster and more informative training plots (#1114) * faster and more informative training plots * Update utils.py Looks good. Needs pep8 linting, I'll do that in PyCharm later once PR is in. * Update test.py * Update train.py f for the tb descriptor lets us plot several batches, i.e. to allow us to change L292 to 'if ni < 3' for 3 examples. Co-authored-by: Glenn Jocher --- test.py | 12 ++-- train.py | 4 +- utils/utils.py | 156 ++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 144 insertions(+), 28 deletions(-) diff --git a/test.py b/test.py index b79f1d1a..0dfd2149 100644 --- a/test.py +++ b/test.py @@ -82,11 +82,6 @@ def test(cfg, nb, _, height, width = imgs.shape # batch size, channels, height, width whwh = torch.Tensor([width, height, width, height]).to(device) - # Plot images with bounding boxes - f = 'test_batch%g.jpg' % batch_i # filename - if batch_i < 1 and not os.path.exists(f): - plot_images(imgs=imgs, targets=targets, paths=paths, fname=f) - # Disable gradients with torch.no_grad(): # Run model @@ -167,6 +162,13 @@ def test(cfg, # Append statistics (correct, conf, pcls, tcls) stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) + # Plot images + if batch_i < 1: + f = 'test_batch%g_gt.jpg' % batch_i # filename + plot_images(images=imgs, targets=targets, paths=paths, names=names, fname=f) # ground truth + f = 'test_batch%g_pred.jpg' % batch_i # filename + plot_images(images=imgs, targets=output_to_target(output, width, height), paths=paths, names=names, fname=f) # predictions + # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy if len(stats): diff --git a/train.py b/train.py index 67c20aa3..5ee55f88 100644 --- a/train.py +++ b/train.py @@ -292,9 +292,9 @@ def train(): # Plot if ni < 1: f = 'train_batch%g.jpg' % i # filename - plot_images(imgs=imgs, targets=targets, paths=paths, fname=f) + res = plot_images(images=imgs, targets=targets, paths=paths, fname=f) if tb_writer: - tb_writer.add_image(f, cv2.imread(f)[:, :, ::-1], dataformats='HWC') + tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch) # tb_writer.add_graph(model, imgs) # add model to tensorboard # end batch ------------------------------------------------------------------------------------------------ diff --git a/utils/utils.py b/utils/utils.py index 4547de31..1335e46b 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -829,6 +829,35 @@ def fitness(x): return (x[:, :4] * w).sum(1) +def output_to_target(output, width, height): + """ + Convert a YOLO model output to target format + + [batch_id, class_id, x, y, w, h, conf] + + """ + + if isinstance(output, torch.Tensor): + output = output.cpu().numpy() + + targets = [] + for i, o in enumerate(output): + + if o is not None: + for pred in o: + box = pred[:4] + w = (box[2]-box[0])/width + h = (box[3]-box[1])/height + x = box[0]/width + w/2 + y = box[1]/height + h/2 + conf = pred[4] + cls = int(pred[5]) + + targets.append([i, cls, x, y, w, h, conf]) + + return np.array(targets) + + # Plotting functions --------------------------------------------------------------------------------------------------- def plot_one_box(x, img, color=None, label=None, line_thickness=None): # Plots one bounding box on image img @@ -864,30 +893,115 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods() fig.savefig('comparison.png', dpi=200) -def plot_images(imgs, targets, paths=None, fname='images.png'): - # Plots training images overlaid with targets - imgs = imgs.cpu().numpy() - targets = targets.cpu().numpy() - # targets = targets[targets[:, 1] == 21] # plot only one class +def plot_images(images, targets, paths=None, fname='images.jpg', names=None, class_labels=True, confidence_labels=True, max_size=640, max_subplots=16): - fig = plt.figure(figsize=(10, 10)) - bs, _, h, w = imgs.shape # batch size, _, height, width - bs = min(bs, 16) # limit plot to 16 images - ns = np.ceil(bs ** 0.5) # number of subplots + if isinstance(images, torch.Tensor): + images = images.cpu().numpy() + + if isinstance(targets, torch.Tensor): + targets = targets.cpu().numpy() + + # un-normalise + if np.max(images[0]) <= 1: + images *= 255 + + bs, _, h, w = images.shape # batch size, _, height, width + bs = min(bs, max_subplots) # limit plot images + ns = np.ceil(bs ** 0.5) # number of subplots (square) + + # Check if we should resize + should_resize = False + if w > max_size or h > max_size: + scale_factor = max_size/max(h, w) + h = math.ceil(scale_factor*h) + w = math.ceil(scale_factor*w) + should_resize=True + + # Empty array for output + mosaic_width = int(ns*w) + mosaic_height = int(ns*h) + mosaic = 255*np.ones((mosaic_height, mosaic_width, 3), dtype=np.uint8) + + # Fix class - colour map + prop_cycle = plt.rcParams['axes.prop_cycle'] + # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb + hex2rgb = lambda h : tuple(int(h[1+i:1+i+2], 16) for i in (0, 2, 4)) + color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']] - for i in range(bs): - boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).T - boxes[[0, 2]] *= w - boxes[[1, 3]] *= h - plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0)) - plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') - plt.axis('off') + for i, image in enumerate(images): + + # e.g. if the last batch has fewer images than we expect + if i == max_subplots: + break + + block_x = int(w * (i // ns)) + block_y = int(h * (i % ns)) + + image = image.transpose(1,2,0) + + if should_resize: + image = cv2.resize(image, (w, h)) + + mosaic[block_y:block_y+h, block_x:block_x+w,:] = image + + if targets is not None: + image_targets = targets[targets[:, 0] == i] + boxes = xywh2xyxy(image_targets[:,2:6]).T + classes = image_targets[:,1].astype('int') + + # Check if we have object confidences (gt vs pred) + confidences = None + if image_targets.shape[1] > 6: + confidences = image_targets[:,6] + + boxes[[0, 2]] *= w + boxes[[0, 2]] += block_x + + boxes[[1, 3]] *= h + boxes[[1, 3]] += block_y + + for j, box in enumerate(boxes.T): + color = color_lut[int(classes[j]) % len(color_lut)] + box = box.astype(int) + cv2.rectangle(mosaic, (box[0], box[1]), (box[2], box[3]), color, thickness=2) + + # Draw class label + if class_labels and max_size > 250: + label = str(classes[j]) if names is None else names[classes[j]] + if confidences is not None and confidence_labels: + label += " {:1.2f}".format(confidences[j]) + + font_scale = 0.4/10 * min(20, h * 0.05) + font_thickness = 2 if max(w, h) > 320 else 1 + + label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness) + cv2.rectangle(mosaic, (box[0], box[1]), (box[0]+label_size[0], box[1]-label_size[1]), color, thickness=-1) + cv2.putText(mosaic, label, (box[0], box[1]), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=font_thickness, color=(255,255,255)) + + # Draw image filename labels if paths is not None: - s = Path(paths[i]).name - plt.title(s[:min(len(s), 40)], fontdict={'size': 8}) # limit to 40 characters - fig.tight_layout() - fig.savefig(fname, dpi=200) - plt.close() + # Trim to 40 chars + label = os.path.basename(paths[i])[:40] + + # Empirical calculation to fit label + # 0.4 is at most (13, 10) px per char at thickness = 1 + # Fit label to 20px high, or shrink if it would be too big + max_font_scale = (w/len(label))*(0.4/8) + font_scale = min(0.4 * 20/8.5, max_font_scale) + font_thickness = 1 + + label_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, font_scale, font_thickness) + + cv2.rectangle(mosaic, (block_x+5, block_y+label_size[1]+baseline+5), (block_x+label_size[0]+5, block_y), 0, thickness=-1) + cv2.putText(mosaic, label, (block_x+5, block_y+label_size[1]+5), cv2.FONT_HERSHEY_DUPLEX, font_scale, (255,255,255), font_thickness) + + # Image border + cv2.rectangle(mosaic, (block_x, block_y), (block_x+w, block_y+h), (255,255,255), thickness=3) + + if fname is not None: + cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) + + return mosaic def plot_test_txt(): # from utils.utils import *; plot_test()