updates
This commit is contained in:
parent
fd949a8eb3
commit
674d0de170
|
@ -501,8 +501,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
|||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
||||
|
||||
# Detections ordered as (x1y1x2y2, conf, cls)
|
||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
|
||||
# Expand
|
||||
expand = False
|
||||
if expand:
|
||||
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
||||
pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1) # (x1y1x2y2, conf, cls)
|
||||
else:
|
||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
pred = pred[(-pred[:, 4]).argsort()]
|
||||
|
|
Loading…
Reference in New Issue