updates
This commit is contained in:
parent
9309d35478
commit
2e1c415e59
|
@ -501,10 +501,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||||
|
|
||||||
# Batched NMS
|
# Batched NMS
|
||||||
if method == 'vision_batch':
|
if method == 'vision_batch':
|
||||||
i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4],
|
i = torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)
|
||||||
scores=pred[:, 4],
|
|
||||||
idxs=pred[:, 6],
|
|
||||||
iou_threshold=nms_thres)
|
|
||||||
output[image_i] = pred[i]
|
output[image_i] = pred[i]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue