This commit is contained in:
Glenn Jocher 2019-04-09 12:24:01 +02:00
parent 6cb3c61320
commit 26b115c306
3 changed files with 23 additions and 10 deletions

View File

@ -61,6 +61,10 @@ def test(
targets = targets.to(device)
imgs = imgs.to(device)
# Plot images with bounding boxes
if batch_i == 0 and not os.path.exists('test_batch0.jpg'):
plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg')
# Run model
inf_out, train_out = model(imgs) # inference and training outputs

View File

@ -104,6 +104,8 @@ def train(
model_info(model)
nB = len(dataloader)
n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches
os.remove('train_batch0.jpg') if os.path.exists('train_batch0.jpg') else None
os.remove('test_batch0.jpg') if os.path.exists('test_batch0.jpg') else None
for epoch in range(start_epoch, epochs):
model.train()
print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
@ -127,16 +129,8 @@ def train(
continue
# Plot images with bounding boxes
plot_images = False
if plot_images:
fig = plt.figure(figsize=(10, 10))
for ip in range(len(imgs)):
boxes = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy().T * img_size
plt.subplot(4, 4, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0))
plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
plt.axis('off')
fig.tight_layout()
fig.savefig('batch_%g.jpg' % i, dpi=fig.dpi)
if epoch == 0 and i == 0:
plot_images(imgs=imgs, targets=targets, fname='train_batch0.jpg')
# SGD burn-in
if epoch == 0 and i <= n_burnin:

View File

@ -487,6 +487,21 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
fig.savefig('comparison.jpg', dpi=fig.dpi)
def plot_images(imgs, targets, fname='images.jpg'):
fig = plt.figure(figsize=(10, 10))
img_size = imgs.shape[3]
bs = imgs.shape[0] # batch size
sp = np.ceil(bs ** 0.5) # subplots
for i in range(bs):
boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).numpy().T * img_size
plt.subplot(sp, sp, i + 1).imshow(imgs[i].numpy().transpose(1, 2, 0))
plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
plt.axis('off')
fig.tight_layout()
fig.savefig(fname, dpi=fig.dpi)
def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
# Plot training results files 'results*.txt'
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')