This commit is contained in:
Glenn Jocher 2019-06-15 02:44:01 +02:00
parent 02291622fa
commit 995dc3ca67
2 changed files with 8 additions and 6 deletions

View File

@ -228,7 +228,7 @@ def train(
pred = model(imgs)
# Compute loss
loss, loss_items = compute_loss(pred, targets, model)
loss, loss_items = compute_loss(pred, targets, model, giou_loss=False)
if torch.isnan(loss):
print('WARNING: nan loss detected, ending training')
return results

View File

@ -270,9 +270,9 @@ def wh_iou(box1, box2):
return inter_area / union_area # iou
def compute_loss(p, targets, model): # predictions, targets, model
def compute_loss(p, targets, model, giou_loss=False): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lxy, lwh, lcls, lconf, lgiou = ft([0]), ft([0]), ft([0]), ft([0]), ft([0])
lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
h = model.hyp # hyperparameters
@ -298,9 +298,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)
# lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
if giou_loss:
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
else:
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # class_conf loss
lconf += (k * h['conf']) * BCE(pi0[..., 4], tconf) # obj_conf loss