This commit is contained in:
Glenn Jocher 2019-12-25 14:47:50 -08:00
parent cdc382e313
commit 34d9392bac
1 changed files with 4 additions and 3 deletions

View File

@ -539,14 +539,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
# Apply finite constraint
pred = pred[torch.isfinite(pred).all(1)]
# Get detections sorted by decreasing confidence scores
pred = pred[pred[:, 4].argsort(descending=True)]
# Batched NMS
if method == 'vision_batch':
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)]
continue
# Sort by confidence
if not method.startswith('vision'):
pred = pred[pred[:, 4].argsort(descending=True)]
# All other NMS methods
det_max = []
cls = pred[:, -1]