This commit is contained in:
Glenn Jocher 2019-12-23 10:31:37 -08:00
parent fd3a6a4cba
commit 209cc9e124
1 changed files with 3 additions and 3 deletions

View File

@ -494,14 +494,14 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
# Retain > conf # Retain > conf
pred = pred[pred[:, 4] > conf_thres] pred = pred[pred[:, 4] > conf_thres]
# compute conf # Compute conf
torch.sigmoid_(pred[..., 5:]) torch.sigmoid_(pred[..., 5:])
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
# Apply width-height constraint # Apply width-height constraint
pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)] pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)]
# If none are remaining => process next image # If none remain process next image
if len(pred) == 0: if len(pred) == 0:
continue continue
@ -528,7 +528,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
output[image_i] = pred[i] output[image_i] = pred[i]
continue continue
# Non-maximum suppression # All other NMS methods
det_max = [] det_max = []
cls = pred[:, -1] cls = pred[:, -1]
for c in cls.unique(): for c in cls.unique():