updates
This commit is contained in:
parent
a5160b44ca
commit
f995d6093c
|
@ -491,16 +491,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||||
|
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
for image_i, pred in enumerate(prediction):
|
for image_i, pred in enumerate(prediction):
|
||||||
# Remove rows
|
# Retain > conf
|
||||||
pred = pred[pred[:, 4] > conf_thres] # retain above threshold
|
pred = pred[pred[:, 4] > conf_thres]
|
||||||
|
|
||||||
# compute conf
|
# compute conf
|
||||||
torch.sigmoid_(pred[..., 5:])
|
torch.sigmoid_(pred[..., 5:])
|
||||||
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
|
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
|
||||||
|
|
||||||
# Apply width-height constraint
|
# Apply width-height constraint
|
||||||
i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1)
|
pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)]
|
||||||
pred = pred[i]
|
|
||||||
|
|
||||||
# If none are remaining => process next image
|
# If none are remaining => process next image
|
||||||
if len(pred) == 0:
|
if len(pred) == 0:
|
||||||
|
@ -517,6 +516,9 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||||
conf, j = pred[:, 5:].max(1)
|
conf, j = pred[:, 5:].max(1)
|
||||||
pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls)
|
pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls)
|
||||||
|
|
||||||
|
# Apply finite constraint
|
||||||
|
pred = pred[torch.isfinite(pred).all(1)]
|
||||||
|
|
||||||
# Get detections sorted by decreasing confidence scores
|
# Get detections sorted by decreasing confidence scores
|
||||||
pred = pred[pred[:, 4].argsort(descending=True)]
|
pred = pred[pred[:, 4].argsort(descending=True)]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue