This commit is contained in:
Glenn Jocher 2019-12-23 11:05:55 -08:00
parent 06e88fec08
commit db26b08f5b
1 changed files with 5 additions and 5 deletions

View File

@ -491,13 +491,9 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
output = [None] * len(prediction) output = [None] * len(prediction)
for image_i, pred in enumerate(prediction): for image_i, pred in enumerate(prediction):
# Retain > conf # Apply conf constraint
pred = pred[pred[:, 4] > conf_thres] pred = pred[pred[:, 4] > conf_thres]
# Compute conf
torch.sigmoid_(pred[..., 5:])
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
# Apply width-height constraint # Apply width-height constraint
pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)] pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)]
@ -505,6 +501,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
if len(pred) == 0: if len(pred) == 0:
continue continue
# Compute conf
torch.sigmoid_(pred[..., 5:])
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2) # Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(pred[:, :4]) box = xywh2xyxy(pred[:, :4])