nms torch.mm() update
This commit is contained in:
parent
ae2bc020eb
commit
9f04e175f6
|
@ -564,7 +564,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
# weights /= weights.sum(0) # normalize
|
# weights /= weights.sum(0) # normalize
|
||||||
# x[:, :4] = torch.mm(weights.T, x[:, :4])
|
# x[:, :4] = torch.mm(weights.T, x[:, :4])
|
||||||
weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None] # box weights
|
weights = (box_iou(boxes[i], boxes) > iou_thres) * scores[None] # box weights
|
||||||
x[i, :4] = torch.mm(weights / weights.sum(1, keepdim=True), x[:, :4]).float() # merged boxes
|
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
||||||
except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
|
except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
|
||||||
pass
|
pass
|
||||||
elif method == 'vision':
|
elif method == 'vision':
|
||||||
|
|
Loading…
Reference in New Issue