diff --git a/detect.py b/detect.py index b9ddbf0f..09bb7311 100644 --- a/detect.py +++ b/detect.py @@ -86,7 +86,7 @@ def detect(save_img=False): pred = pred.float() # Apply NMS - pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes) + pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) # Apply Classifier if classify: @@ -169,6 +169,7 @@ if __name__ == '__main__': parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='display results') parser.add_argument('--classes', nargs='+', type=int, help='filter by class') + parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') opt = parser.parse_args() print(opt) diff --git a/utils/utils.py b/utils/utils.py index 70e40667..91dc62a9 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -499,7 +499,7 @@ def build_targets(model, targets): return tcls, tbox, indices, av -def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None): +def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, classes=None, agnostic=False): """ Removes detections with lower object confidence score than 'conf_thres' Non-Maximum Suppression to further filter detections. @@ -511,6 +511,7 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru # Box constraints min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height + method = 'vision_batch' output = [None] * len(prediction) for image_i, pred in enumerate(prediction): # Apply conf constraint @@ -548,7 +549,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru # Batched NMS if method == 'vision_batch': - output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], iou_thres)] + c = j * 0 if agnostic else j # class-agnostic NMS + output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], c, iou_thres)] continue # Sort by confidence