diff --git a/test.py b/test.py index 96248f72..a5cec3b4 100644 --- a/test.py +++ b/test.py @@ -76,8 +76,9 @@ def test(cfg, _, _, height, width = imgs.shape # batch size, channels, height, width # Plot images with bounding boxes - if batch_i == 0 and not os.path.exists('test_batch0.jpg'): - plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg') + if batch_i == 0 and not os.path.exists('test_batch0.png'): + plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.png') + # Disable gradients with torch.no_grad(): diff --git a/train.py b/train.py index 8f09bae3..15a6d2ad 100644 --- a/train.py +++ b/train.py @@ -73,7 +73,7 @@ def train(): nc = 1 if opt.single_cls else int(data_dict['classes']) # number of classes # Remove previous results - for f in glob.glob('*_batch*.jpg') + glob.glob(results_file): + for f in glob.glob('*_batch*.png') + glob.glob(results_file): os.remove(f) # Initialize model @@ -255,7 +255,7 @@ def train(): # Plot images with bounding boxes if ni == 0: - fname = 'train_batch%g.jpg' % i + fname = 'train_batch%g.png' % i plot_images(imgs=imgs, targets=targets, paths=paths, fname=fname) if tb_writer: tb_writer.add_image(fname, cv2.imread(fname)[:, :, ::-1], dataformats='HWC') diff --git a/utils/utils.py b/utils/utils.py index a0a29815..82b80cde 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -911,7 +911,8 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods() fig.savefig('comparison.png', dpi=200) -def plot_images(imgs, targets, paths=None, fname='images.jpg'): + +def plot_images(imgs, targets, paths=None, fname='images.png'): # Plots training images overlaid with targets imgs = imgs.cpu().numpy() targets = targets.cpu().numpy() @@ -947,13 +948,13 @@ def plot_test_txt(): # from utils.utils import *; plot_test() ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0) ax.set_aspect('equal') fig.tight_layout() - plt.savefig('hist2d.jpg', dpi=300) + plt.savefig('hist2d.png', dpi=300) fig, ax = plt.subplots(1, 2, figsize=(12, 6)) ax[0].hist(cx, bins=600) ax[1].hist(cy, bins=600) fig.tight_layout() - plt.savefig('hist1d.jpg', dpi=200) + plt.savefig('hist1d.png', dpi=200) def plot_targets_txt(): # from utils.utils import *; plot_targets_txt()