updates
This commit is contained in:
parent
ce9a2cb9d2
commit
9309d35478
|
@ -459,7 +459,7 @@ def build_targets(model, targets):
|
||||||
return tcls, tbox, indices, av
|
return tcls, tbox, indices, av
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=False, method='vision'):
|
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=True, method='vision'):
|
||||||
"""
|
"""
|
||||||
Removes detections with lower object confidence score than 'conf_thres'
|
Removes detections with lower object confidence score than 'conf_thres'
|
||||||
Non-Maximum Suppression to further filter detections.
|
Non-Maximum Suppression to further filter detections.
|
||||||
|
@ -489,7 +489,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal
|
||||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
||||||
|
|
||||||
# Multi-class
|
# Multi-class
|
||||||
if multi_cls:
|
if multi_cls or conf_thres < 0.01:
|
||||||
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
||||||
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||||
else: # best class only
|
else: # best class only
|
||||||
|
|
Loading…
Reference in New Issue