From 1191dee71b0ef507bf4d95d2acfcc89555fb630c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 15 Apr 2019 13:55:52 +0200 Subject: [PATCH] updates --- train.py | 15 +++++++-------- utils/utils.py | 11 ++--------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/train.py b/train.py index 94e5e1e7..4dca56bc 100644 --- a/train.py +++ b/train.py @@ -123,7 +123,7 @@ def train( if int(name.split('.')[1]) < cutoff: # if layer < 75 p.requires_grad = False if epoch == 0 else True - mloss = defaultdict(float) # mean loss + mloss = torch.zeros(5).to(device) # mean losses for i, (imgs, targets, _, _) in enumerate(dataloader): imgs = imgs.to(device) targets = targets.to(device) @@ -148,7 +148,7 @@ def train( target_list = build_targets(model, targets) # Compute loss - loss, loss_dict = compute_loss(pred, target_list) + loss, loss_items = compute_loss(pred, target_list) # Compute gradient if mixed_precision: @@ -162,14 +162,13 @@ def train( optimizer.step() optimizer.zero_grad() - # Running epoch-means of tracked metrics - for key, val in loss_dict.items(): - mloss[key] = (mloss[key] * i + val) / (i + 1) + # Update running mean of tracked metrics + mloss = (mloss * i + loss_items) / (i + 1) + # Print batch results s = ('%8s%12s' + '%10.3g' * 7) % ( - '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, nB - 1), - mloss['xy'], mloss['wh'], mloss['conf'], mloss['cls'], - mloss['total'], nt, time.time() - t) + '%g/%g' % (epoch, epochs - 1), + '%g/%g' % (i, nB - 1), *mloss, nt, time.time() - t) t = time.time() print(s) diff --git a/utils/utils.py b/utils/utils.py index 6b7949cf..165a0f4b 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,6 +1,5 @@ import glob import random -from collections import defaultdict import cv2 import matplotlib @@ -244,7 +243,7 @@ def wh_iou(box1, box2): def compute_loss(p, targets): # predictions, targets - FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor + FT = torch.cuda.Tensor if p[0].is_cuda else torch.Tensor lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]) txy, twh, tcls, indices = targets MSE = nn.MSELoss() @@ -274,13 +273,7 @@ def compute_loss(p, targets): # predictions, targets lconf += (k * 64) * BCE(pi0[..., 4], tconf) # obj_conf loss loss = lxy + lwh + lconf + lcls - # Add to dictionary - d = defaultdict(float) - losses = [loss.item(), lxy.item(), lwh.item(), lconf.item(), lcls.item()] - for k, v in zip(['total', 'xy', 'wh', 'conf', 'cls'], losses): - d[k] = v - - return loss, d + return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach() def build_targets(model, targets):