simplify train.py

This commit is contained in:
Glenn Jocher 2018-09-19 04:21:46 +02:00
parent 68de92f1a1
commit 29fbcb059f
1 changed files with 33 additions and 40 deletions

View File

@ -90,7 +90,7 @@ def main(opt):
modelinfo(model) modelinfo(model)
t0, t1 = time.time(), time.time() t0, t1 = time.time(), time.time()
print('%10s' * 16 % ( print('%10s' * 16 % (
'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R', 'nGT', 'TP', 'FP', 'FN', 'time')) 'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R', 'nTargets', 'TP', 'FP', 'FN', 'time'))
for epoch in range(opt.epochs): for epoch in range(opt.epochs):
epoch += start_epoch epoch += start_epoch
@ -115,14 +115,10 @@ def main(opt):
metrics = torch.zeros(4, num_classes) metrics = torch.zeros(4, num_classes)
for i, (imgs, targets) in enumerate(dataloader): for i, (imgs, targets) in enumerate(dataloader):
n = opt.batch_size # number of pictures at a time if sum([len(x) for x in targets]) < 1: # if no targets continue
for j in range(int(len(imgs) / n)):
targets_j = targets[j * n:j * n + n]
nGT = sum([len(x) for x in targets_j])
if nGT < 1:
continue continue
loss = model(imgs[j * n:j * n + n].to(device), targets_j, requestPrecision=True, epoch=epoch) loss = model(imgs.to(device), targets, requestPrecision=True, epoch=epoch)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -151,20 +147,17 @@ def main(opt):
s = ('%10s%10s' + '%10.3g' * 14) % ( s = ('%10s%10s' + '%10.3g' * 14) % (
'%g/%g' % (epoch, opt.epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'], '%g/%g' % (epoch, opt.epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'], rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
rloss['loss'], mean_precision, mean_recall, model.losses['nGT'], model.losses['TP'], rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'],
model.losses['FP'], model.losses['FN'], time.time() - t1) model.losses['FP'], model.losses['FN'], time.time() - t1)
t1 = time.time() t1 = time.time()
print(s) print(s)
# if i == 1:
# return
# Write epoch results # Write epoch results
with open('results.txt', 'a') as file: with open('results.txt', 'a') as file:
file.write(s + '\n') file.write(s + '\n')
# Update best loss # Update best loss
loss_per_target = rloss['loss'] / rloss['nGT'] loss_per_target = rloss['loss'] / rloss['nT']
if loss_per_target < best_loss: if loss_per_target < best_loss:
best_loss = loss_per_target best_loss = loss_per_target