From ce0b41467739486f840a66967326f8456a9686ff Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Aug 2019 02:04:49 +0200 Subject: [PATCH] updates --- train.py | 3 +-- utils/utils.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index ceea5782..cea491c3 100644 --- a/train.py +++ b/train.py @@ -266,7 +266,7 @@ def train(cfg, pred = model(imgs) # Compute loss - loss, loss_items = compute_loss(pred, targets, model, giou_loss=not opt.xywh) + loss, loss_items = compute_loss(pred, targets, model) if torch.isnan(loss): print('WARNING: nan loss detected, ending training') return results @@ -368,7 +368,6 @@ if __name__ == '__main__': parser.add_argument('--transfer', action='store_true', help='transfer learning flag') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--notest', action='store_true', help='only test final epoch') - parser.add_argument('--xywh', action='store_true', help='use xywh loss instead of GIoU loss') parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--img-weights', action='store_true', help='select training images by weight') diff --git a/utils/utils.py b/utils/utils.py index e8fdbfa4..76b53ebc 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -312,7 +312,7 @@ class FocalLoss(nn.Module): return loss -def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model +def compute_loss(p, targets, model): # predictions, targets, model ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor lcls, lbox, lobj = ft([0]), ft([0]), ft([0]) tcls, tbox, indices, anchor_vec = build_targets(model, targets)