diff --git a/test.py b/test.py index f6408699..e2282c86 100644 --- a/test.py +++ b/test.py @@ -64,7 +64,7 @@ def test( # Plot images with bounding boxes if batch_i == 0 and not os.path.exists('test_batch0.jpg'): - plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg') + plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg') # Run model inf_out, train_out = model(imgs) # inference and training outputs diff --git a/train.py b/train.py index 48dc60a0..685b6317 100644 --- a/train.py +++ b/train.py @@ -178,7 +178,7 @@ def train( mloss = torch.zeros(5).to(device) # mean losses pbar = tqdm(enumerate(dataloader), total=nb) # progress bar - for i, (imgs, targets, _, _) in pbar: + for i, (imgs, targets, paths, _) in pbar: imgs = imgs.to(device) targets = targets.to(device) @@ -192,7 +192,7 @@ def train( # Plot images with bounding boxes if epoch == 0 and i == 0: - plot_images(imgs=imgs, targets=targets, fname='train_batch%g.jpg' % i) + plot_images(imgs=imgs, targets=targets, paths=paths, fname='train_batch%g.jpg' % i) # SGD burn-in if epoch == 0 and i <= n_burnin: diff --git a/utils/utils.py b/utils/utils.py index 767fd21e..767b33e4 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from PIL import Image from tqdm import tqdm +from pathlib import Path from . import torch_utils from . import google_utils @@ -611,7 +612,7 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods() fig.savefig('comparison.png', dpi=300) -def plot_images(imgs, targets, fname='images.jpg'): +def plot_images(imgs, targets, paths=None, fname='images.jpg'): # Plots training images overlaid with targets imgs = imgs.cpu().numpy() targets = targets.cpu().numpy() @@ -627,6 +628,8 @@ def plot_images(imgs, targets, fname='images.jpg'): plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0)) plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') plt.axis('off') + if paths is not None: + plt.title(Path(paths[i]).name, fontdict={'size': 8}) fig.tight_layout() fig.savefig(fname, dpi=300) plt.close()