updates
This commit is contained in:
parent
43230c48bf
commit
ce0b414677
3
train.py
3
train.py
|
@ -266,7 +266,7 @@ def train(cfg,
|
|||
pred = model(imgs)
|
||||
|
||||
# Compute loss
|
||||
loss, loss_items = compute_loss(pred, targets, model, giou_loss=not opt.xywh)
|
||||
loss, loss_items = compute_loss(pred, targets, model)
|
||||
if torch.isnan(loss):
|
||||
print('WARNING: nan loss detected, ending training')
|
||||
return results
|
||||
|
@ -368,7 +368,6 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--transfer', action='store_true', help='transfer learning flag')
|
||||
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
|
||||
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
||||
parser.add_argument('--xywh', action='store_true', help='use xywh loss instead of GIoU loss')
|
||||
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
||||
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
|
||||
parser.add_argument('--img-weights', action='store_true', help='select training images by weight')
|
||||
|
|
|
@ -312,7 +312,7 @@ class FocalLoss(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
|
||||
def compute_loss(p, targets, model): # predictions, targets, model
|
||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||
|
|
Loading…
Reference in New Issue