updates
This commit is contained in:
parent
674d0de170
commit
aaaaa06156
|
@ -459,15 +459,15 @@ def build_targets(model, targets):
|
|||
return tcls, tbox, indices, av
|
||||
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=True, method='vision'):
|
||||
"""
|
||||
Removes detections with lower object confidence score than 'conf_thres'
|
||||
Non-Maximum Suppression to further filter detections.
|
||||
Returns detections with shape:
|
||||
(x1, y1, x2, y2, object_conf, conf, class)
|
||||
"""
|
||||
# NMS method https://github.com/ultralytics/yolov3/issues/679 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED'
|
||||
method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest
|
||||
# NMS method https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch'
|
||||
# method = 'merge' if conf_thres <= 0.01 else 'vision' # MERGE is highest mAP, VISION is fastest
|
||||
|
||||
# Box constraints
|
||||
min_wh, max_wh = 2, 10000 # (pixels) minimum and maximium box width and height
|
||||
|
@ -501,19 +501,18 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
|||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
||||
|
||||
# Expand
|
||||
expand = False
|
||||
if expand:
|
||||
# Multi-class
|
||||
if multi_cls:
|
||||
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
||||
pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1) # (x1y1x2y2, conf, cls)
|
||||
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
else:
|
||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
|
||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) # (xyxy, conf, cls)
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
pred = pred[(-pred[:, 4]).argsort()]
|
||||
|
||||
# Batched NMS
|
||||
if method == 'VISION_BATCHED':
|
||||
if method == 'vision_batch':
|
||||
i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4],
|
||||
scores=pred[:, 4],
|
||||
idxs=pred[:, 6],
|
||||
|
@ -532,11 +531,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
|||
elif n > 500:
|
||||
dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117
|
||||
|
||||
if method == 'VISION':
|
||||
if method == 'vision':
|
||||
i = torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres)
|
||||
det_max.append(dc[i])
|
||||
|
||||
elif method == 'OR': # default
|
||||
elif method == 'or': # default
|
||||
# METHOD1
|
||||
# ind = list(range(len(dc)))
|
||||
# while len(ind):
|
||||
|
@ -553,14 +552,14 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
|||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
|
||||
elif method == 'AND': # requires overlap, single boxes erased
|
||||
elif method == 'and': # requires overlap, single boxes erased
|
||||
while len(dc) > 1:
|
||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||
if iou.max() > 0.5:
|
||||
det_max.append(dc[:1])
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
|
||||
elif method == 'MERGE': # weighted mixture box
|
||||
elif method == 'merge': # weighted mixture box
|
||||
while len(dc):
|
||||
if len(dc) == 1:
|
||||
det_max.append(dc)
|
||||
|
@ -571,7 +570,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
|||
det_max.append(dc[:1])
|
||||
dc = dc[i == 0]
|
||||
|
||||
elif method == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503
|
||||
elif method == 'soft': # soft-NMS https://arxiv.org/abs/1704.04503
|
||||
sigma = 0.5 # soft-nms sigma parameter
|
||||
while len(dc):
|
||||
if len(dc) == 1:
|
||||
|
|
Loading…
Reference in New Issue