GIoU to default
This commit is contained in:
parent
7246dd855c
commit
b649a95c9a
3
test.py
3
test.py
|
@ -71,8 +71,7 @@ def test(
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
if hasattr(model, 'hyp'): # if model has loss hyperparameters
|
if hasattr(model, 'hyp'): # if model has loss hyperparameters
|
||||||
loss_i, _ = compute_loss(train_out, targets, model)
|
loss += compute_loss(train_out, targets, model)[0].item()
|
||||||
loss += loss_i.item()
|
|
||||||
|
|
||||||
# Run NMS
|
# Run NMS
|
||||||
output = non_max_suppression(inf_out, conf_thres=conf_thres, nms_thres=nms_thres)
|
output = non_max_suppression(inf_out, conf_thres=conf_thres, nms_thres=nms_thres)
|
||||||
|
|
4
train.py
4
train.py
|
@ -218,7 +218,7 @@ def train(
|
||||||
pred = model(imgs)
|
pred = model(imgs)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
loss, loss_items = compute_loss(pred, targets, model, giou_loss=opt.giou)
|
loss, loss_items = compute_loss(pred, targets, model, giou_loss=not opt.xywh)
|
||||||
if torch.isnan(loss):
|
if torch.isnan(loss):
|
||||||
print('WARNING: nan loss detected, ending training')
|
print('WARNING: nan loss detected, ending training')
|
||||||
return results
|
return results
|
||||||
|
@ -320,7 +320,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
|
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
|
||||||
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
|
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
|
||||||
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
||||||
parser.add_argument('--giou', action='store_true', help='use GIoU loss instead of xy, wh loss')
|
parser.add_argument('--xywh', action='store_true', help='use xywh loss instead of GIoU loss')
|
||||||
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
||||||
parser.add_argument('--cloud-evolve', action='store_true', help='evolve hyperparameters from a cloud source')
|
parser.add_argument('--cloud-evolve', action='store_true', help='evolve hyperparameters from a cloud source')
|
||||||
parser.add_argument('--var', default=0, type=int, help='debug variable')
|
parser.add_argument('--var', default=0, type=int, help='debug variable')
|
||||||
|
|
|
@ -271,7 +271,7 @@ def wh_iou(box1, box2):
|
||||||
return inter_area / union_area # iou
|
return inter_area / union_area # iou
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(p, targets, model, giou_loss=False): # predictions, targets, model
|
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
|
||||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||||
lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0])
|
lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0])
|
||||||
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||||
|
@ -336,17 +336,17 @@ def build_targets(model, targets):
|
||||||
if nt:
|
if nt:
|
||||||
iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0)
|
iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0)
|
||||||
|
|
||||||
use_best = True
|
use_best_anchor = False
|
||||||
if use_best:
|
if use_best_anchor:
|
||||||
iou, a = iou.max(0) # best iou and anchor
|
iou, a = iou.max(0) # best iou and anchor
|
||||||
else:
|
else: # use all anchors
|
||||||
na = len(layer.anchor_vec) # number of anchors
|
na = len(layer.anchor_vec) # number of anchors
|
||||||
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
|
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
|
||||||
t = targets.repeat([na, 1])
|
t = targets.repeat([na, 1])
|
||||||
gwh = gwh.repeat([na, 1])
|
gwh = gwh.repeat([na, 1])
|
||||||
iou = iou.view(-1) # use all ious
|
iou = iou.view(-1) # use all ious
|
||||||
|
|
||||||
# reject below threshold ious (OPTIONAL, increases P, lowers R)
|
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R)
|
||||||
reject = True
|
reject = True
|
||||||
if reject:
|
if reject:
|
||||||
j = iou > iou_thres
|
j = iou > iou_thres
|
||||||
|
|
Loading…
Reference in New Issue