Merge NMS update
This commit is contained in:
parent
94344f5bea
commit
eac07f9da3
|
@ -505,8 +505,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
# Box constraints
|
# Box constraints
|
||||||
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
||||||
|
|
||||||
method = 'vision_batch'
|
method = 'vision'
|
||||||
batched = 'batch' in method # run once per image, all classes simultaneously
|
|
||||||
nc = prediction[0].shape[1] - 5 # number of classes
|
nc = prediction[0].shape[1] - 5 # number of classes
|
||||||
multi_label &= nc > 1 # multiple labels per box
|
multi_label &= nc > 1 # multiple labels per box
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
|
@ -548,93 +547,26 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Sort by confidence
|
# Sort by confidence
|
||||||
if not method.startswith('vision'):
|
# if method == 'fast_batch':
|
||||||
pred = pred[pred[:, 4].argsort(descending=True)]
|
# pred = pred[pred[:, 4].argsort(descending=True)]
|
||||||
|
|
||||||
# Batched NMS
|
# Batched NMS
|
||||||
if batched:
|
c = pred[:, 5] * 0 if agnostic else pred[:, 5] # classes
|
||||||
c = pred[:, 5] * 0 if agnostic else pred[:, 5] # class-agnostic NMS
|
|
||||||
boxes, scores = pred[:, :4].clone(), pred[:, 4]
|
boxes, scores = pred[:, :4].clone(), pred[:, 4]
|
||||||
boxes += c.view(-1, 1) * max_wh
|
boxes += c.view(-1, 1) * max_wh # offset boxes by class
|
||||||
if method == 'vision_batch':
|
if method == 'vision':
|
||||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||||
elif method == 'merge_batch': # Merge NMS
|
elif method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||||
iou = box_iou(boxes, boxes[i]).tril_() # upper triangular iou matrix
|
iou = box_iou(boxes, boxes[i]).tril_() # lower triangular iou matrix
|
||||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||||
weights /= weights.sum(0)
|
weights /= weights.sum(0)
|
||||||
pred[i, :4] = torch.matmul(weights.T, pred[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
pred[i, :4] = torch.matmul(weights.T, pred[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||||
elif method == 'fast_batch': # FastNMS from https://github.com/dbolya/yolact
|
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact
|
||||||
iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix
|
iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix
|
||||||
i = iou.max(0)[0] < iou_thres
|
i = iou.max(0)[0] < iou_thres
|
||||||
|
|
||||||
output[image_i] = pred[i]
|
output[image_i] = pred[i]
|
||||||
continue
|
|
||||||
|
|
||||||
# All other NMS methods
|
|
||||||
det_max = []
|
|
||||||
cls = pred[:, -1]
|
|
||||||
for c in cls.unique():
|
|
||||||
dc = pred[cls == c] # select class c
|
|
||||||
n = len(dc)
|
|
||||||
if n == 1:
|
|
||||||
det_max.append(dc) # No NMS required if only 1 prediction
|
|
||||||
continue
|
|
||||||
elif n > 500:
|
|
||||||
dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117
|
|
||||||
|
|
||||||
if method == 'or': # default
|
|
||||||
# METHOD1
|
|
||||||
# ind = list(range(len(dc)))
|
|
||||||
# while len(ind):
|
|
||||||
# j = ind[0]
|
|
||||||
# det_max.append(dc[j:j + 1]) # save highest conf detection
|
|
||||||
# reject = (bbox_iou(dc[j], dc[ind]) > iou_thres).nonzero()
|
|
||||||
# [ind.pop(i) for i in reversed(reject)]
|
|
||||||
|
|
||||||
# METHOD2
|
|
||||||
while dc.shape[0]:
|
|
||||||
det_max.append(dc[:1]) # save highest conf detection
|
|
||||||
if len(dc) == 1: # Stop if we're at the last detection
|
|
||||||
break
|
|
||||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
|
||||||
dc = dc[1:][iou < iou_thres] # remove ious > threshold
|
|
||||||
|
|
||||||
elif method == 'and': # requires overlap, single boxes erased
|
|
||||||
while len(dc) > 1:
|
|
||||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
|
||||||
if iou.max() > 0.5:
|
|
||||||
det_max.append(dc[:1])
|
|
||||||
dc = dc[1:][iou < iou_thres] # remove ious > threshold
|
|
||||||
|
|
||||||
elif method == 'merge': # weighted mixture box
|
|
||||||
while len(dc):
|
|
||||||
if len(dc) == 1:
|
|
||||||
det_max.append(dc)
|
|
||||||
break
|
|
||||||
i = bbox_iou(dc[0], dc) > iou_thres # iou with other boxes
|
|
||||||
weights = dc[i, 4:5]
|
|
||||||
dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
|
|
||||||
det_max.append(dc[:1])
|
|
||||||
dc = dc[i == 0]
|
|
||||||
|
|
||||||
elif method == 'soft': # soft-NMS https://arxiv.org/abs/1704.04503
|
|
||||||
sigma = 0.5 # soft-nms sigma parameter
|
|
||||||
while len(dc):
|
|
||||||
if len(dc) == 1:
|
|
||||||
det_max.append(dc)
|
|
||||||
break
|
|
||||||
det_max.append(dc[:1])
|
|
||||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
|
||||||
dc = dc[1:]
|
|
||||||
dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences
|
|
||||||
dc = dc[dc[:, 4] > conf_thres] # https://github.com/ultralytics/yolov3/issues/362
|
|
||||||
|
|
||||||
if len(det_max):
|
|
||||||
det_max = torch.cat(det_max) # concatenate
|
|
||||||
output[image_i] = det_max[det_max[:, 4].argsort(descending=True)] # sort
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def get_yolo_layers(model):
|
def get_yolo_layers(model):
|
||||||
|
|
Loading…
Reference in New Issue