diff --git a/train.py b/train.py index b13e8e75..c9830b85 100644 --- a/train.py +++ b/train.py @@ -228,9 +228,9 @@ def train(cfg, # Plot images with bounding boxes if epoch == 0 and i == 0: fname = 'train_batch%g.jpg' % i - fig_data = plot_images(imgs=imgs, targets=targets, paths=paths, fname=fname) + plot_images(imgs=imgs, targets=targets, paths=paths, fname=fname) if tb_writer: - tb_writer.add_image(fname, fig_data, dataformats='HWC') + tb_writer.add_image(fname, cv2.imread(fname)[:, :, ::-1], dataformats='HWC') # Hyperparameter burn-in # n_burn = nb - 1 # min(nb // 5 + 1, 1000) # number of burn-in batches diff --git a/utils/utils.py b/utils/utils.py index 24c15122..1af93d49 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -697,9 +697,7 @@ def plot_images(imgs, targets, paths=None, fname='images.jpg'): plt.title(s[:min(len(s), 40)], fontdict={'size': 8}) # limit to 40 characters fig.tight_layout() fig.savefig(fname, dpi=200) - fig_image = fig_to_data(fig) plt.close() - return fig_image def plot_test_txt(): # from utils.utils import *; plot_test() @@ -816,9 +814,3 @@ def plot_results_orig(start=0, stop=0): # from utils.utils import *; plot_resul def version_to_tuple(version): # Used to compare versions of library return tuple(map(int, (version.split(".")))) - - -def fig_to_data(fig): - # Converts a matplotlib fig to 3D numpy array (fig is a matplotlib figure) - fig.canvas.draw() - return np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3] # RGB image