updates
This commit is contained in:
parent
ecce92d5d8
commit
a5677d3f90
|
@ -466,7 +466,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
Returns detections with shape:
|
Returns detections with shape:
|
||||||
(x1, y1, x2, y2, object_conf, class_conf, class)
|
(x1, y1, x2, y2, object_conf, class_conf, class)
|
||||||
"""
|
"""
|
||||||
|
# NMS method https://github.com/ultralytics/yolov3/issues/679 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED'
|
||||||
|
method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest
|
||||||
|
|
||||||
|
# Box constraints
|
||||||
min_wh, max_wh = 2, 10000 # (pixels) minimum and maximium box width and height
|
min_wh, max_wh = 2, 10000 # (pixels) minimum and maximium box width and height
|
||||||
|
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
|
@ -516,10 +519,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
# Get detections sorted by decreasing confidence scores
|
# Get detections sorted by decreasing confidence scores
|
||||||
pred = pred[(-pred[:, 4]).argsort()]
|
pred = pred[(-pred[:, 4]).argsort()]
|
||||||
|
|
||||||
# Set NMS method https://github.com/ultralytics/yolov3/issues/679
|
|
||||||
# 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED'
|
|
||||||
method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest
|
|
||||||
|
|
||||||
# Batched NMS
|
# Batched NMS
|
||||||
if method == 'VISION_BATCHED':
|
if method == 'VISION_BATCHED':
|
||||||
i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4],
|
i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4],
|
||||||
|
|
Loading…
Reference in New Issue