updates
This commit is contained in:
parent
378f08c6d5
commit
692b006f4d
|
@ -501,7 +501,7 @@ def build_targets(model, targets):
|
||||||
return tcls, tbox, indices, av
|
return tcls, tbox, indices, av
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=True, classes=None, agnostic=False):
|
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
|
||||||
"""
|
"""
|
||||||
Removes detections with lower object confidence score than 'conf_thres'
|
Removes detections with lower object confidence score than 'conf_thres'
|
||||||
Non-Maximum Suppression to further filter detections.
|
Non-Maximum Suppression to further filter detections.
|
||||||
|
@ -516,7 +516,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=Tru
|
||||||
method = 'vision_batch'
|
method = 'vision_batch'
|
||||||
batched = 'batch' in method # run once per image, all classes simultaneously
|
batched = 'batch' in method # run once per image, all classes simultaneously
|
||||||
nc = prediction[0].shape[1] - 5 # number of classes
|
nc = prediction[0].shape[1] - 5 # number of classes
|
||||||
multi_cls = multi_cls and (nc > 1) # allow multiple classes per anchor
|
multi_label &= nc > 1 # multiple labels per box
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
for image_i, pred in enumerate(prediction):
|
for image_i, pred in enumerate(prediction):
|
||||||
# Apply conf constraint
|
# Apply conf constraint
|
||||||
|
@ -536,7 +536,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=Tru
|
||||||
box = xywh2xyxy(pred[:, :4])
|
box = xywh2xyxy(pred[:, :4])
|
||||||
|
|
||||||
# Detections matrix nx6 (xyxy, conf, cls)
|
# Detections matrix nx6 (xyxy, conf, cls)
|
||||||
if multi_cls:
|
if multi_label:
|
||||||
i, j = (pred[:, 5:] > conf_thres).nonzero().t()
|
i, j = (pred[:, 5:] > conf_thres).nonzero().t()
|
||||||
pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||||
else: # best class only
|
else: # best class only
|
||||||
|
|
Loading…
Reference in New Issue