diff --git a/utils/utils.py b/utils/utils.py index 54adc05e..c396be03 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -459,15 +459,15 @@ def build_targets(model, targets): return tcls, tbox, indices, av -def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): +def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=True, method='vision'): """ Removes detections with lower object confidence score than 'conf_thres' Non-Maximum Suppression to further filter detections. Returns detections with shape: (x1, y1, x2, y2, object_conf, conf, class) """ - # NMS method https://github.com/ultralytics/yolov3/issues/679 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED' - method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest + # NMS method https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch' + # method = 'merge' if conf_thres <= 0.01 else 'vision' # MERGE is highest mAP, VISION is fastest # Box constraints min_wh, max_wh = 2, 10000 # (pixels) minimum and maximium box width and height @@ -501,19 +501,18 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): # Box (center x, center y, width, height) to (x1, y1, x2, y2) pred[:, :4] = xywh2xyxy(pred[:, :4]) - # Expand - expand = False - if expand: + # Multi-class + if multi_cls: i, j = (pred[:, 4:] > conf_thres).nonzero().t() - pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1) # (x1y1x2y2, conf, cls) + pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1) else: - pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) + pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) # (xyxy, conf, cls) # Get detections sorted by decreasing confidence scores pred = pred[(-pred[:, 4]).argsort()] # Batched NMS - if method == 'VISION_BATCHED': + if method == 'vision_batch': i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4], scores=pred[:, 4], idxs=pred[:, 6], @@ -532,11 +531,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): elif n > 500: dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117 - if method == 'VISION': + if method == 'vision': i = torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres) det_max.append(dc[i]) - elif method == 'OR': # default + elif method == 'or': # default # METHOD1 # ind = list(range(len(dc))) # while len(ind): @@ -553,14 +552,14 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes dc = dc[1:][iou < nms_thres] # remove ious > threshold - elif method == 'AND': # requires overlap, single boxes erased + elif method == 'and': # requires overlap, single boxes erased while len(dc) > 1: iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes if iou.max() > 0.5: det_max.append(dc[:1]) dc = dc[1:][iou < nms_thres] # remove ious > threshold - elif method == 'MERGE': # weighted mixture box + elif method == 'merge': # weighted mixture box while len(dc): if len(dc) == 1: det_max.append(dc) @@ -571,7 +570,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): det_max.append(dc[:1]) dc = dc[i == 0] - elif method == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503 + elif method == 'soft': # soft-NMS https://arxiv.org/abs/1704.04503 sigma = 0.5 # soft-nms sigma parameter while len(dc): if len(dc) == 1: