merge NMS full matrix

This commit is contained in:
Glenn Jocher 2020-04-02 18:53:40 -07:00
parent aa4591d7e9
commit 207c6fcff9
1 changed files with 2 additions and 2 deletions

View File

@ -555,12 +555,12 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
if method == 'merge': # Merge NMS (boxes merged using weighted mean) if method == 'merge': # Merge NMS (boxes merged using weighted mean)
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
if n < 1E4: # update boxes if n < 1E4: # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1) # box weights # weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1) # box weights
# weights /= weights.sum(0) # normalize # weights /= weights.sum(0) # normalize
# x[:, :4] = torch.mm(weights.T, x[:, :4]) # x[:, :4] = torch.mm(weights.T, x[:, :4])
weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None] # box weights weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None] # box weights
x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True), x[:, :4]) # boxes(i,4) = w(i,n) * boxes(n,4) x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True), x[:, :4]).float() # merged boxes
elif method == 'vision': elif method == 'vision':
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact