GIoU to default

This commit is contained in:
glenn-jocher 2019-07-07 23:24:34 +02:00
parent 32a52dfb02
commit 70f6379601
3 changed files with 7 additions and 4 deletions

View File

@ -64,7 +64,7 @@ def test(
# Plot images with bounding boxes
if batch_i == 0 and not os.path.exists('test_batch0.jpg'):
plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg')
plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg')
# Run model
inf_out, train_out = model(imgs) # inference and training outputs

View File

@ -178,7 +178,7 @@ def train(
mloss = torch.zeros(5).to(device) # mean losses
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
for i, (imgs, targets, _, _) in pbar:
for i, (imgs, targets, paths, _) in pbar:
imgs = imgs.to(device)
targets = targets.to(device)
@ -192,7 +192,7 @@ def train(
# Plot images with bounding boxes
if epoch == 0 and i == 0:
plot_images(imgs=imgs, targets=targets, fname='train_batch%g.jpg' % i)
plot_images(imgs=imgs, targets=targets, paths=paths, fname='train_batch%g.jpg' % i)
# SGD burn-in
if epoch == 0 and i <= n_burnin:

View File

@ -9,6 +9,7 @@ import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from . import torch_utils
from . import google_utils
@ -611,7 +612,7 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
fig.savefig('comparison.png', dpi=300)
def plot_images(imgs, targets, fname='images.jpg'):
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
# Plots training images overlaid with targets
imgs = imgs.cpu().numpy()
targets = targets.cpu().numpy()
@ -627,6 +628,8 @@ def plot_images(imgs, targets, fname='images.jpg'):
plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
plt.axis('off')
if paths is not None:
plt.title(Path(paths[i]).name, fontdict={'size': 8})
fig.tight_layout()
fig.savefig(fname, dpi=300)
plt.close()