updates
This commit is contained in:
parent
cdc382e313
commit
34d9392bac
|
@ -539,14 +539,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||||
# Apply finite constraint
|
# Apply finite constraint
|
||||||
pred = pred[torch.isfinite(pred).all(1)]
|
pred = pred[torch.isfinite(pred).all(1)]
|
||||||
|
|
||||||
# Get detections sorted by decreasing confidence scores
|
|
||||||
pred = pred[pred[:, 4].argsort(descending=True)]
|
|
||||||
|
|
||||||
# Batched NMS
|
# Batched NMS
|
||||||
if method == 'vision_batch':
|
if method == 'vision_batch':
|
||||||
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)]
|
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Sort by confidence
|
||||||
|
if not method.startswith('vision'):
|
||||||
|
pred = pred[pred[:, 4].argsort(descending=True)]
|
||||||
|
|
||||||
# All other NMS methods
|
# All other NMS methods
|
||||||
det_max = []
|
det_max = []
|
||||||
cls = pred[:, -1]
|
cls = pred[:, -1]
|
||||||
|
|
Loading…
Reference in New Issue