updates
This commit is contained in:
parent
3c6b168a0a
commit
1191dee71b
15
train.py
15
train.py
|
@ -123,7 +123,7 @@ def train(
|
|||
if int(name.split('.')[1]) < cutoff: # if layer < 75
|
||||
p.requires_grad = False if epoch == 0 else True
|
||||
|
||||
mloss = defaultdict(float) # mean loss
|
||||
mloss = torch.zeros(5).to(device) # mean losses
|
||||
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
||||
imgs = imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
@ -148,7 +148,7 @@ def train(
|
|||
target_list = build_targets(model, targets)
|
||||
|
||||
# Compute loss
|
||||
loss, loss_dict = compute_loss(pred, target_list)
|
||||
loss, loss_items = compute_loss(pred, target_list)
|
||||
|
||||
# Compute gradient
|
||||
if mixed_precision:
|
||||
|
@ -162,14 +162,13 @@ def train(
|
|||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Running epoch-means of tracked metrics
|
||||
for key, val in loss_dict.items():
|
||||
mloss[key] = (mloss[key] * i + val) / (i + 1)
|
||||
# Update running mean of tracked metrics
|
||||
mloss = (mloss * i + loss_items) / (i + 1)
|
||||
|
||||
# Print batch results
|
||||
s = ('%8s%12s' + '%10.3g' * 7) % (
|
||||
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, nB - 1),
|
||||
mloss['xy'], mloss['wh'], mloss['conf'], mloss['cls'],
|
||||
mloss['total'], nt, time.time() - t)
|
||||
'%g/%g' % (epoch, epochs - 1),
|
||||
'%g/%g' % (i, nB - 1), *mloss, nt, time.time() - t)
|
||||
t = time.time()
|
||||
print(s)
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import glob
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
|
@ -244,7 +243,7 @@ def wh_iou(box1, box2):
|
|||
|
||||
|
||||
def compute_loss(p, targets): # predictions, targets
|
||||
FT = torch.cuda.FloatTensor if p[0].is_cuda else torch.FloatTensor
|
||||
FT = torch.cuda.Tensor if p[0].is_cuda else torch.Tensor
|
||||
lxy, lwh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0])
|
||||
txy, twh, tcls, indices = targets
|
||||
MSE = nn.MSELoss()
|
||||
|
@ -274,13 +273,7 @@ def compute_loss(p, targets): # predictions, targets
|
|||
lconf += (k * 64) * BCE(pi0[..., 4], tconf) # obj_conf loss
|
||||
loss = lxy + lwh + lconf + lcls
|
||||
|
||||
# Add to dictionary
|
||||
d = defaultdict(float)
|
||||
losses = [loss.item(), lxy.item(), lwh.item(), lconf.item(), lcls.item()]
|
||||
for k, v in zip(['total', 'xy', 'wh', 'conf', 'cls'], losses):
|
||||
d[k] = v
|
||||
|
||||
return loss, d
|
||||
return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach()
|
||||
|
||||
|
||||
def build_targets(model, targets):
|
||||
|
|
Loading…
Reference in New Issue