From ce9a2cb9d217d22440f4456815aa26fdba8a73aa Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 19 Dec 2019 19:23:09 -0800 Subject: [PATCH] updates --- utils/utils.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 5a1a4912..6ffdfee8 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -474,24 +474,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal output = [None] * len(prediction) for image_i, pred in enumerate(prediction): - # Duplicate ambiguous - # b = pred[pred[:, 5:].sum(1) > 1.1] - # if len(b): - # b[range(len(b)), 5 + b[:, 5:].argmax(1)] = 0 - # pred = torch.cat((pred, b), 0) - - # Multiply conf by class conf to get combined confidence - conf, cls = pred[:, 4:].max(1) - - # # Merge classes (optional) - # cls[(cls.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2 - # - # # Remove classes (optional) - # pred[cls != 2, 4] = 0.0 + # Remove rows + pred = pred[(pred[:, 4:] > conf_thres).any(1)] # retain above threshold # Select only suitable predictions - i = (conf > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite( - pred).all(1) + i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1) pred = pred[i] # If none are remaining => process next image @@ -505,11 +492,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal if multi_cls: i, j = (pred[:, 4:] > conf_thres).nonzero().t() pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1) - else: - pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) # (xyxy, conf, cls) + else: # best class only + conf, j = pred[:, 4:].max(1) + pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls) # Get detections sorted by decreasing confidence scores - pred = pred[(-pred[:, 4]).argsort()] + pred = pred[pred[:, 4].argsort(descending=True)] # Batched NMS if method == 'vision_batch':