GIoU to default
This commit is contained in:
parent
32a52dfb02
commit
70f6379601
2
test.py
2
test.py
|
@ -64,7 +64,7 @@ def test(
|
||||||
|
|
||||||
# Plot images with bounding boxes
|
# Plot images with bounding boxes
|
||||||
if batch_i == 0 and not os.path.exists('test_batch0.jpg'):
|
if batch_i == 0 and not os.path.exists('test_batch0.jpg'):
|
||||||
plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg')
|
plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg')
|
||||||
|
|
||||||
# Run model
|
# Run model
|
||||||
inf_out, train_out = model(imgs) # inference and training outputs
|
inf_out, train_out = model(imgs) # inference and training outputs
|
||||||
|
|
4
train.py
4
train.py
|
@ -178,7 +178,7 @@ def train(
|
||||||
|
|
||||||
mloss = torch.zeros(5).to(device) # mean losses
|
mloss = torch.zeros(5).to(device) # mean losses
|
||||||
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
||||||
for i, (imgs, targets, _, _) in pbar:
|
for i, (imgs, targets, paths, _) in pbar:
|
||||||
imgs = imgs.to(device)
|
imgs = imgs.to(device)
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
|
|
||||||
|
@ -192,7 +192,7 @@ def train(
|
||||||
|
|
||||||
# Plot images with bounding boxes
|
# Plot images with bounding boxes
|
||||||
if epoch == 0 and i == 0:
|
if epoch == 0 and i == 0:
|
||||||
plot_images(imgs=imgs, targets=targets, fname='train_batch%g.jpg' % i)
|
plot_images(imgs=imgs, targets=targets, paths=paths, fname='train_batch%g.jpg' % i)
|
||||||
|
|
||||||
# SGD burn-in
|
# SGD burn-in
|
||||||
if epoch == 0 and i <= n_burnin:
|
if epoch == 0 and i <= n_burnin:
|
||||||
|
|
|
@ -9,6 +9,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from . import torch_utils
|
from . import torch_utils
|
||||||
from . import google_utils
|
from . import google_utils
|
||||||
|
@ -611,7 +612,7 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
|
||||||
fig.savefig('comparison.png', dpi=300)
|
fig.savefig('comparison.png', dpi=300)
|
||||||
|
|
||||||
|
|
||||||
def plot_images(imgs, targets, fname='images.jpg'):
|
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
|
||||||
# Plots training images overlaid with targets
|
# Plots training images overlaid with targets
|
||||||
imgs = imgs.cpu().numpy()
|
imgs = imgs.cpu().numpy()
|
||||||
targets = targets.cpu().numpy()
|
targets = targets.cpu().numpy()
|
||||||
|
@ -627,6 +628,8 @@ def plot_images(imgs, targets, fname='images.jpg'):
|
||||||
plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
|
plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
|
||||||
plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
|
plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
|
||||||
plt.axis('off')
|
plt.axis('off')
|
||||||
|
if paths is not None:
|
||||||
|
plt.title(Path(paths[i]).name, fontdict={'size': 8})
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
fig.savefig(fname, dpi=300)
|
fig.savefig(fname, dpi=300)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
Loading…
Reference in New Issue