updates
This commit is contained in:
parent
f995d6093c
commit
fd3a6a4cba
|
@ -506,15 +506,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
box = xywh2xyxy(pred[:, :4])
|
||||||
|
|
||||||
# Multi-class
|
# Detections matrix nx6 (xyxy, conf, cls)
|
||||||
if multi_cls or conf_thres < 0.01:
|
if multi_cls or conf_thres < 0.01:
|
||||||
i, j = (pred[:, 5:] > conf_thres).nonzero().t()
|
i, j = (pred[:, 5:] > conf_thres).nonzero().t()
|
||||||
pred = torch.cat((pred[i, :4], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||||
else: # best class only
|
else: # best class only
|
||||||
conf, j = pred[:, 5:].max(1)
|
conf, j = pred[:, 5:].max(1)
|
||||||
pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls)
|
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||||
|
|
||||||
# Apply finite constraint
|
# Apply finite constraint
|
||||||
pred = pred[torch.isfinite(pred).all(1)]
|
pred = pred[torch.isfinite(pred).all(1)]
|
||||||
|
|
Loading…
Reference in New Issue