GIoU to default
This commit is contained in:
parent
32a52dfb02
commit
70f6379601
2
test.py
2
test.py
|
@ -64,7 +64,7 @@ def test(
|
|||
|
||||
# Plot images with bounding boxes
|
||||
if batch_i == 0 and not os.path.exists('test_batch0.jpg'):
|
||||
plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg')
|
||||
plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg')
|
||||
|
||||
# Run model
|
||||
inf_out, train_out = model(imgs) # inference and training outputs
|
||||
|
|
4
train.py
4
train.py
|
@ -178,7 +178,7 @@ def train(
|
|||
|
||||
mloss = torch.zeros(5).to(device) # mean losses
|
||||
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
||||
for i, (imgs, targets, _, _) in pbar:
|
||||
for i, (imgs, targets, paths, _) in pbar:
|
||||
imgs = imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
|
@ -192,7 +192,7 @@ def train(
|
|||
|
||||
# Plot images with bounding boxes
|
||||
if epoch == 0 and i == 0:
|
||||
plot_images(imgs=imgs, targets=targets, fname='train_batch%g.jpg' % i)
|
||||
plot_images(imgs=imgs, targets=targets, paths=paths, fname='train_batch%g.jpg' % i)
|
||||
|
||||
# SGD burn-in
|
||||
if epoch == 0 and i <= n_burnin:
|
||||
|
|
|
@ -9,6 +9,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
from . import torch_utils
|
||||
from . import google_utils
|
||||
|
@ -611,7 +612,7 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
|
|||
fig.savefig('comparison.png', dpi=300)
|
||||
|
||||
|
||||
def plot_images(imgs, targets, fname='images.jpg'):
|
||||
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
|
||||
# Plots training images overlaid with targets
|
||||
imgs = imgs.cpu().numpy()
|
||||
targets = targets.cpu().numpy()
|
||||
|
@ -627,6 +628,8 @@ def plot_images(imgs, targets, fname='images.jpg'):
|
|||
plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
|
||||
plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
|
||||
plt.axis('off')
|
||||
if paths is not None:
|
||||
plt.title(Path(paths[i]).name, fontdict={'size': 8})
|
||||
fig.tight_layout()
|
||||
fig.savefig(fname, dpi=300)
|
||||
plt.close()
|
||||
|
|
Loading…
Reference in New Issue