updates
This commit is contained in:
parent
fd949a8eb3
commit
674d0de170
|
@ -501,8 +501,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
||||||
|
|
||||||
# Detections ordered as (x1y1x2y2, conf, cls)
|
# Expand
|
||||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
|
expand = False
|
||||||
|
if expand:
|
||||||
|
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
||||||
|
pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1) # (x1y1x2y2, conf, cls)
|
||||||
|
else:
|
||||||
|
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
|
||||||
|
|
||||||
# Get detections sorted by decreasing confidence scores
|
# Get detections sorted by decreasing confidence scores
|
||||||
pred = pred[(-pred[:, 4]).argsort()]
|
pred = pred[(-pred[:, 4]).argsort()]
|
||||||
|
|
Loading…
Reference in New Issue