nms torch.mm() update

This commit is contained in:
Glenn Jocher 2020-05-10 11:26:37 -07:00
parent ae2bc020eb
commit 9f04e175f6
1 changed files with 1 additions and 1 deletions

View File

@ -564,7 +564,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
# weights /= weights.sum(0) # normalize # weights /= weights.sum(0) # normalize
# x[:, :4] = torch.mm(weights.T, x[:, :4]) # x[:, :4] = torch.mm(weights.T, x[:, :4])
weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None] # box weights weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None] # box weights
x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True), x[:, :4]).float() # merged boxes x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139 except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
pass pass
elif method == 'vision': elif method == 'vision':