nms torch.mm() update
This commit is contained in:
parent
ae2bc020eb
commit
9f04e175f6
|
@ -564,7 +564,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
|||
# weights /= weights.sum(0) # normalize
|
||||
# x[:, :4] = torch.mm(weights.T, x[:, :4])
|
||||
weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None] # box weights
|
||||
x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True), x[:, :4]).float() # merged boxes
|
||||
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
||||
except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
|
||||
pass
|
||||
elif method == 'vision':
|
||||
|
|
Loading…
Reference in New Issue