From 26b115c306d31563b66333ad1f6253e5ec5f687c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 9 Apr 2019 12:24:01 +0200 Subject: [PATCH] updates --- test.py | 4 ++++ train.py | 14 ++++---------- utils/utils.py | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/test.py b/test.py index 30c6f152..98a53e34 100644 --- a/test.py +++ b/test.py @@ -61,6 +61,10 @@ def test( targets = targets.to(device) imgs = imgs.to(device) + # Plot images with bounding boxes + if batch_i == 0 and not os.path.exists('test_batch0.jpg'): + plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg') + # Run model inf_out, train_out = model(imgs) # inference and training outputs diff --git a/train.py b/train.py index 50b418cf..58792870 100644 --- a/train.py +++ b/train.py @@ -104,6 +104,8 @@ def train( model_info(model) nB = len(dataloader) n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches + os.remove('train_batch0.jpg') if os.path.exists('train_batch0.jpg') else None + os.remove('test_batch0.jpg') if os.path.exists('test_batch0.jpg') else None for epoch in range(start_epoch, epochs): model.train() print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time')) @@ -127,16 +129,8 @@ def train( continue # Plot images with bounding boxes - plot_images = False - if plot_images: - fig = plt.figure(figsize=(10, 10)) - for ip in range(len(imgs)): - boxes = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy().T * img_size - plt.subplot(4, 4, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0)) - plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') - plt.axis('off') - fig.tight_layout() - fig.savefig('batch_%g.jpg' % i, dpi=fig.dpi) + if epoch == 0 and i == 0: + plot_images(imgs=imgs, targets=targets, fname='train_batch0.jpg') # SGD burn-in if epoch == 0 and i <= n_burnin: diff --git a/utils/utils.py b/utils/utils.py index 7b7e5108..b55680a2 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -487,6 +487,21 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods() fig.savefig('comparison.jpg', dpi=fig.dpi) +def plot_images(imgs, targets, fname='images.jpg'): + fig = plt.figure(figsize=(10, 10)) + img_size = imgs.shape[3] + bs = imgs.shape[0] # batch size + sp = np.ceil(bs ** 0.5) # subplots + + for i in range(bs): + boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).numpy().T * img_size + plt.subplot(sp, sp, i + 1).imshow(imgs[i].numpy().transpose(1, 2, 0)) + plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') + plt.axis('off') + fig.tight_layout() + fig.savefig(fname, dpi=fig.dpi) + + def plot_results(start=0, stop=0): # from utils.utils import *; plot_results() # Plot training results files 'results*.txt' # import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')