updates
This commit is contained in:
parent
308f7c8563
commit
166f8c0e53
|
@ -501,7 +501,7 @@ def build_targets(model, targets):
|
||||||
return tcls, tbox, indices, av
|
return tcls, tbox, indices, av
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, classes=None, agnostic=False):
|
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=True, classes=None, agnostic=False):
|
||||||
"""
|
"""
|
||||||
Removes detections with lower object confidence score than 'conf_thres'
|
Removes detections with lower object confidence score than 'conf_thres'
|
||||||
Non-Maximum Suppression to further filter detections.
|
Non-Maximum Suppression to further filter detections.
|
||||||
|
@ -513,7 +513,8 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||||
# Box constraints
|
# Box constraints
|
||||||
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
||||||
|
|
||||||
method = 'vision_batch'
|
method = 'fast_batch'
|
||||||
|
batched = 'batch' in method # run once per image, all classes simultaneously
|
||||||
nc = prediction[0].shape[1] - 5 # number of classes
|
nc = prediction[0].shape[1] - 5 # number of classes
|
||||||
multi_cls = multi_cls and (nc > 1) # allow multiple classes per anchor
|
multi_cls = multi_cls and (nc > 1) # allow multiple classes per anchor
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
|
@ -550,16 +551,24 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
|
||||||
if not pred.shape[0]:
|
if not pred.shape[0]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Batched NMS
|
|
||||||
if method == 'vision_batch':
|
|
||||||
c = pred[:, 5] * 0 if agnostic else pred[:, 5] # class-agnostic NMS
|
|
||||||
output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], c, iou_thres)]
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Sort by confidence
|
# Sort by confidence
|
||||||
if not method.startswith('vision'):
|
if not method.startswith('vision'):
|
||||||
pred = pred[pred[:, 4].argsort(descending=True)]
|
pred = pred[pred[:, 4].argsort(descending=True)]
|
||||||
|
|
||||||
|
# Batched NMS
|
||||||
|
if batched:
|
||||||
|
c = pred[:, 5] * 0 if agnostic else pred[:, 5] # class-agnostic NMS
|
||||||
|
boxes, scores = pred[:, :4].clone(), pred[:, 4]
|
||||||
|
if method == 'vision_batch':
|
||||||
|
i = torchvision.ops.boxes.batched_nms(boxes, scores, c, iou_thres)
|
||||||
|
elif method == 'fast_batch': # FastNMS from https://github.com/dbolya/yolact
|
||||||
|
boxes += c.view(-1, 1) * max_wh
|
||||||
|
iou = box_iou(boxes, boxes).triu_(diagonal=1) # zero upper triangle iou matrix
|
||||||
|
i = iou.max(dim=0)[0] < iou_thres
|
||||||
|
|
||||||
|
output[image_i] = pred[i]
|
||||||
|
continue
|
||||||
|
|
||||||
# All other NMS methods
|
# All other NMS methods
|
||||||
det_max = []
|
det_max = []
|
||||||
cls = pred[:, -1]
|
cls = pred[:, -1]
|
||||||
|
|
Loading…
Reference in New Issue