diff --git a/utils/utils.py b/utils/utils.py index 820bd936..d3a7ecf5 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -457,10 +457,10 @@ def build_targets(model, targets): t, a = targets, [] gwh = t[:, 4:6] * ng if nt: - iou = wh_iou(anchor_vec, gwh) + iou = wh_iou(anchor_vec, gwh) # iou(3,n) = wh_iou(anchor_vec(3,2), gwh(n,2)) if use_all_anchors: - na = len(anchor_vec) # number of anchors + na = anchor_vec.shape[0] # number of anchors a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1) t = targets.repeat([na, 1]) gwh = gwh.repeat([na, 1])