[conf > conf_thres] update
This commit is contained in:
parent
5b572681ff
commit
0c7d7427e4
|
@ -482,7 +482,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
for xi, x in enumerate(prediction): # image index, image inference
|
for xi, x in enumerate(prediction): # image index, image inference
|
||||||
# Apply constraints
|
# Apply constraints
|
||||||
x = x[x[:, 4] > conf_thres] # confidence
|
x = x[x[:, 4] > conf_thres] # confidence
|
||||||
# x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height
|
x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height
|
||||||
|
|
||||||
# If none remain process next image
|
# If none remain process next image
|
||||||
if not x.shape[0]:
|
if not x.shape[0]:
|
||||||
|
@ -500,7 +500,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||||
else: # best class only
|
else: # best class only
|
||||||
conf, j = x[:, 5:].max(1)
|
conf, j = x[:, 5:].max(1)
|
||||||
x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf > conf_thres]
|
||||||
|
|
||||||
# Filter by class
|
# Filter by class
|
||||||
if classes:
|
if classes:
|
||||||
|
|
Loading…
Reference in New Issue