This commit is contained in:
Glenn Jocher 2019-08-18 02:04:49 +02:00
parent 43230c48bf
commit ce0b414677
2 changed files with 2 additions and 3 deletions

View File

@ -266,7 +266,7 @@ def train(cfg,
pred = model(imgs) pred = model(imgs)
# Compute loss # Compute loss
loss, loss_items = compute_loss(pred, targets, model, giou_loss=not opt.xywh) loss, loss_items = compute_loss(pred, targets, model)
if torch.isnan(loss): if torch.isnan(loss):
print('WARNING: nan loss detected, ending training') print('WARNING: nan loss detected, ending training')
return results return results
@ -368,7 +368,6 @@ if __name__ == '__main__':
parser.add_argument('--transfer', action='store_true', help='transfer learning flag') parser.add_argument('--transfer', action='store_true', help='transfer learning flag')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch') parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--xywh', action='store_true', help='use xywh loss instead of GIoU loss')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--img-weights', action='store_true', help='select training images by weight') parser.add_argument('--img-weights', action='store_true', help='select training images by weight')

View File

@ -312,7 +312,7 @@ class FocalLoss(nn.Module):
return loss return loss
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model def compute_loss(p, targets, model): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lcls, lbox, lobj = ft([0]), ft([0]), ft([0]) lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
tcls, tbox, indices, anchor_vec = build_targets(model, targets) tcls, tbox, indices, anchor_vec = build_targets(model, targets)