This commit is contained in:
Glenn Jocher 2019-02-27 00:04:41 +01:00
parent eb6a4b5b84
commit 9a27339e04
1 changed files with 2 additions and 1 deletions

View File

@ -349,7 +349,8 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
v = ((pred[:, 4] > conf_thres) & (class_prob > .4)) # TODO examine arbitrary 0.4 thres here
# v = ((pred[:, 4] > conf_thres) & (class_prob > .4)) # TODO examine arbitrary 0.4 thres here
v = pred[:, 4] > conf_thres
v = v.nonzero().squeeze()
if len(v.shape) == 0:
v = v.unsqueeze(0)