From 166f8c0e5378c68e718a89420efbab66d3cf9f18 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 4 Mar 2020 00:07:19 -0800 Subject: [PATCH] updates --- utils/utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 62e2cf49..607cfb1b 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -501,7 +501,7 @@ def build_targets(model, targets): return tcls, tbox, indices, av -def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, classes=None, agnostic=False): +def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=True, classes=None, agnostic=False): """ Removes detections with lower object confidence score than 'conf_thres' Non-Maximum Suppression to further filter detections. @@ -513,7 +513,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru # Box constraints min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height - method = 'vision_batch' + method = 'fast_batch' + batched = 'batch' in method # run once per image, all classes simultaneously nc = prediction[0].shape[1] - 5 # number of classes multi_cls = multi_cls and (nc > 1) # allow multiple classes per anchor output = [None] * len(prediction) @@ -550,16 +551,24 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru if not pred.shape[0]: continue - # Batched NMS - if method == 'vision_batch': - c = pred[:, 5] * 0 if agnostic else pred[:, 5] # class-agnostic NMS - output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], c, iou_thres)] - continue - # Sort by confidence if not method.startswith('vision'): pred = pred[pred[:, 4].argsort(descending=True)] + # Batched NMS + if batched: + c = pred[:, 5] * 0 if agnostic else pred[:, 5] # class-agnostic NMS + boxes, scores = pred[:, :4].clone(), pred[:, 4] + if method == 'vision_batch': + i = torchvision.ops.boxes.batched_nms(boxes, scores, c, iou_thres) + elif method == 'fast_batch': # FastNMS from https://github.com/dbolya/yolact + boxes += c.view(-1, 1) * max_wh + iou = box_iou(boxes, boxes).triu_(diagonal=1) # zero upper triangle iou matrix + i = iou.max(dim=0)[0] < iou_thres + + output[image_i] = pred[i] + continue + # All other NMS methods det_max = [] cls = pred[:, -1]