updates
This commit is contained in:
parent
f8aab0e952
commit
8b2f85c290
|
@ -461,6 +461,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
class_conf, class_pred = pred[:, 5:].max(1)
|
class_conf, class_pred = pred[:, 5:].max(1)
|
||||||
pred[:, 4] *= class_conf
|
pred[:, 4] *= class_conf
|
||||||
|
|
||||||
|
# # Merge classes (optional)
|
||||||
|
# class_pred[(class_pred.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
|
||||||
|
#
|
||||||
|
# # Remove classes (optional)
|
||||||
|
# pred[class_pred != 2, 4] = 0.0
|
||||||
|
|
||||||
# Select only suitable predictions
|
# Select only suitable predictions
|
||||||
i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & torch.isfinite(pred).all(1)
|
i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & torch.isfinite(pred).all(1)
|
||||||
pred = pred[i]
|
pred = pred[i]
|
||||||
|
|
Loading…
Reference in New Issue