diff --git a/utils/utils.py b/utils/utils.py index 2240403d..68aab83e 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -482,7 +482,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T for xi, x in enumerate(prediction): # image index, image inference # Apply constraints x = x[x[:, 4] > conf_thres] # confidence - # x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height + x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height # If none remain process next image if not x.shape[0]: @@ -500,7 +500,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) else: # best class only conf, j = x[:, 5:].max(1) - x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) + x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf > conf_thres] # Filter by class if classes: