diff --git a/test.py b/test.py index e1a80967..d7cc4294 100644 --- a/test.py +++ b/test.py @@ -25,7 +25,7 @@ def test(cfg, verbose = opt.task == 'test' # Remove previous - for f in glob.glob('test_batch*.jpg'): + for f in glob.glob('test_batch*.png'): os.remove(f) # Initialize model @@ -76,9 +76,9 @@ def test(cfg, _, _, height, width = imgs.shape # batch size, channels, height, width # Plot images with bounding boxes - if batch_i == 0 and not os.path.exists('test_batch0.png'): - plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.png') - + f = 'test_batch%g.png' % batch_i # filename + if batch_i < 1 and not os.path.exists(f): + plot_images(imgs=imgs, targets=targets, paths=paths, fname=f) # Disable gradients with torch.no_grad():