diff --git a/utils/utils.py b/utils/utils.py index d7b99549..134cf6ef 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -558,9 +558,15 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T boxes += c.view(-1, 1) * max_wh if method == 'vision_batch': i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) + elif method == 'merge_batch': # Merge NMS + i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) + iou = box_iou(boxes, boxes[i]).tril_() # upper triangular iou matrix + weights = (iou > conf_thres) * scores.view(-1, 1) + weights /= weights.sum(0) + pred[i, :4] = torch.matmul(weights.T, pred[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) elif method == 'fast_batch': # FastNMS from https://github.com/dbolya/yolact iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix - i = iou.max(dim=0)[0] < iou_thres + i = iou.max(0)[0] < iou_thres output[image_i] = pred[i] continue @@ -577,10 +583,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T elif n > 500: dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117 - if method == 'vision': - det_max.append(dc[torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], iou_thres)]) - - elif method == 'or': # default + if method == 'or': # default # METHOD1 # ind = list(range(len(dc))) # while len(ind): @@ -629,7 +632,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T if len(det_max): det_max = torch.cat(det_max) # concatenate - output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort + output[image_i] = det_max[det_max[:, 4].argsort(descending=True)] # sort return output