diff --git a/train.py b/train.py index 56f63ade..2209e0d9 100644 --- a/train.py +++ b/train.py @@ -228,7 +228,7 @@ def train( pred = model(imgs) # Compute loss - loss, loss_items = compute_loss(pred, targets, model) + loss, loss_items = compute_loss(pred, targets, model, giou_loss=False) if torch.isnan(loss): print('WARNING: nan loss detected, ending training') return results diff --git a/utils/utils.py b/utils/utils.py index cb8c7434..fff6d03f 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -270,9 +270,9 @@ def wh_iou(box1, box2): return inter_area / union_area # iou -def compute_loss(p, targets, model): # predictions, targets, model +def compute_loss(p, targets, model, giou_loss=False): # predictions, targets, model ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor - lxy, lwh, lcls, lconf, lgiou = ft([0]), ft([0]), ft([0]), ft([0]), ft([0]) + lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0]) txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets) h = model.hyp # hyperparameters @@ -298,9 +298,11 @@ def compute_loss(p, targets, model): # predictions, targets, model pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) - # lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss - lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss - lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss + if giou_loss: + lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss + else: + lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss + lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # class_conf loss lconf += (k * h['conf']) * BCE(pi0[..., 4], tconf) # obj_conf loss