From 29fbcb059f74959dff15be7f1ee033b34cf0ea53 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 19 Sep 2018 04:21:46 +0200 Subject: [PATCH] simplify train.py --- train.py | 73 +++++++++++++++++++++++++------------------------------- 1 file changed, 33 insertions(+), 40 deletions(-) diff --git a/train.py b/train.py index 626bbcc0..76d96efb 100644 --- a/train.py +++ b/train.py @@ -90,7 +90,7 @@ def main(opt): modelinfo(model) t0, t1 = time.time(), time.time() print('%10s' * 16 % ( - 'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R', 'nGT', 'TP', 'FP', 'FN', 'time')) + 'Epoch', 'Batch', 'x', 'y', 'w', 'h', 'conf', 'cls', 'total', 'P', 'R', 'nTargets', 'TP', 'FP', 'FN', 'time')) for epoch in range(opt.epochs): epoch += start_epoch @@ -115,56 +115,49 @@ def main(opt): metrics = torch.zeros(4, num_classes) for i, (imgs, targets) in enumerate(dataloader): - n = opt.batch_size # number of pictures at a time - for j in range(int(len(imgs) / n)): - targets_j = targets[j * n:j * n + n] - nGT = sum([len(x) for x in targets_j]) - if nGT < 1: - continue + if sum([len(x) for x in targets]) < 1: # if no targets continue + continue - loss = model(imgs[j * n:j * n + n].to(device), targets_j, requestPrecision=True, epoch=epoch) - optimizer.zero_grad() - loss.backward() - optimizer.step() + loss = model(imgs.to(device), targets, requestPrecision=True, epoch=epoch) + optimizer.zero_grad() + loss.backward() + optimizer.step() - ui += 1 - metrics += model.losses['metrics'] - for key, val in model.losses.items(): - rloss[key] = (rloss[key] * ui + val) / (ui + 1) + ui += 1 + metrics += model.losses['metrics'] + for key, val in model.losses.items(): + rloss[key] = (rloss[key] * ui + val) / (ui + 1) - # Precision - precision = metrics[0] / (metrics[0] + metrics[1] + 1e-16) - k = (metrics[0] + metrics[1]) > 0 - if k.sum() > 0: - mean_precision = precision[k].mean() - else: - mean_precision = 0 + # Precision + precision = metrics[0] / (metrics[0] + metrics[1] + 1e-16) + k = (metrics[0] + metrics[1]) > 0 + if k.sum() > 0: + mean_precision = precision[k].mean() + else: + mean_precision = 0 - # Recall - recall = metrics[0] / (metrics[0] + metrics[2] + 1e-16) - k = (metrics[0] + metrics[2]) > 0 - if k.sum() > 0: - mean_recall = recall[k].mean() - else: - mean_recall = 0 + # Recall + recall = metrics[0] / (metrics[0] + metrics[2] + 1e-16) + k = (metrics[0] + metrics[2]) > 0 + if k.sum() > 0: + mean_recall = recall[k].mean() + else: + mean_recall = 0 - s = ('%10s%10s' + '%10.3g' * 14) % ( - '%g/%g' % (epoch, opt.epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'], - rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'], - rloss['loss'], mean_precision, mean_recall, model.losses['nGT'], model.losses['TP'], - model.losses['FP'], model.losses['FN'], time.time() - t1) - t1 = time.time() - print(s) - - # if i == 1: - # return + s = ('%10s%10s' + '%10.3g' * 14) % ( + '%g/%g' % (epoch, opt.epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'], + rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'], + rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'], + model.losses['FP'], model.losses['FN'], time.time() - t1) + t1 = time.time() + print(s) # Write epoch results with open('results.txt', 'a') as file: file.write(s + '\n') # Update best loss - loss_per_target = rloss['loss'] / rloss['nGT'] + loss_per_target = rloss['loss'] / rloss['nT'] if loss_per_target < best_loss: best_loss = loss_per_target