This commit is contained in:
Glenn Jocher 2019-04-15 13:55:52 +02:00
parent 3c6b168a0a
commit 1191dee71b
2 changed files with 9 additions and 17 deletions

View File

@ -123,7 +123,7 @@ def train(
if int(name.split('.')[1]) < cutoff: # if layer < 75
p.requires_grad = False if epoch == 0 else True
mloss = defaultdict(float) # mean loss
mloss = torch.zeros(5).to(device) # mean losses
for i, (imgs, targets, _, _) in enumerate(dataloader):
imgs = imgs.to(device)
targets = targets.to(device)
@ -148,7 +148,7 @@ def train(
target_list = build_targets(model, targets)
# Compute loss
loss, loss_dict = compute_loss(pred, target_list)
loss, loss_items = compute_loss(pred, target_list)
# Compute gradient
if mixed_precision:
@ -162,14 +162,13 @@ def train(
optimizer.step()
optimizer.zero_grad()
# Running epoch-means of tracked metrics
for key, val in loss_dict.items():
mloss[key] = (mloss[key] * i + val) / (i + 1)
# Update running mean of tracked metrics
mloss = (mloss * i + loss_items) / (i + 1)
# Print batch results
s = ('%8s%12s' + '%10.3g' * 7) % (
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, nB - 1),
mloss['xy'], mloss['wh'], mloss['conf'], mloss['cls'],
mloss['total'], nt, time.time() - t)
'%g/%g' % (epoch, epochs - 1),
'%g/%g' % (i, nB - 1), *mloss, nt, time.time() - t)
t = time.time()
print(s)

View File

@ -1,6 +1,5 @@
import glob
import random
from collections import defaultdict
import cv2
import matplotlib
@ -244,7 +243,7 @@ def wh_iou(box1, box2):
def compute_loss(p, targets): # predictions, targets
FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor
FT = torch.cuda.Tensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0])
txy, twh, tcls, indices = targets
MSE = nn.MSELoss()
@ -274,13 +273,7 @@ def compute_loss(p, targets): # predictions, targets
lconf += (k * 64) * BCE(pi0[..., 4], tconf) # obj_conf loss
loss = lxy + lwh + lconf + lcls
# Add to dictionary
d = defaultdict(float)
losses = [loss.item(), lxy.item(), lwh.item(), lconf.item(), lcls.item()]
for k, v in zip(['total', 'xy', 'wh', 'conf', 'cls'], losses):
d[k] = v
return loss, d
return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach()
def build_targets(model, targets):