updates
This commit is contained in:
parent
cdc382e313
commit
34d9392bac
|
@ -539,14 +539,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
|||
# Apply finite constraint
|
||||
pred = pred[torch.isfinite(pred).all(1)]
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
pred = pred[pred[:, 4].argsort(descending=True)]
|
||||
|
||||
# Batched NMS
|
||||
if method == 'vision_batch':
|
||||
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)]
|
||||
continue
|
||||
|
||||
# Sort by confidence
|
||||
if not method.startswith('vision'):
|
||||
pred = pred[pred[:, 4].argsort(descending=True)]
|
||||
|
||||
# All other NMS methods
|
||||
det_max = []
|
||||
cls = pred[:, -1]
|
||||
|
|
Loading…
Reference in New Issue