diff --git a/utils/utils.py b/utils/utils.py index c2ac95fc..9eee669c 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -555,10 +555,10 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores if method == 'merge': # Merge NMS (boxes merged using weighted mean) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) - iou = box_iou(boxes, boxes[i]).tril_() # lower triangular iou matrix + iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix weights = (iou > iou_thres) * scores.view(-1, 1) - weights /= weights.sum(0) + 1E-6 - x[i, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) + weights /= weights.sum(0) + x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) elif method == 'vision': i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact