This commit is contained in:
Glenn Jocher 2019-12-19 19:23:09 -08:00
parent 9048d96c71
commit ce9a2cb9d2
1 changed files with 7 additions and 19 deletions

View File

@ -474,24 +474,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal
output = [None] * len(prediction)
for image_i, pred in enumerate(prediction):
# Duplicate ambiguous
# b = pred[pred[:, 5:].sum(1) > 1.1]
# if len(b):
# b[range(len(b)), 5 + b[:, 5:].argmax(1)] = 0
# pred = torch.cat((pred, b), 0)
# Multiply conf by class conf to get combined confidence
conf, cls = pred[:, 4:].max(1)
# # Merge classes (optional)
# cls[(cls.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
#
# # Remove classes (optional)
# pred[cls != 2, 4] = 0.0
# Remove rows
pred = pred[(pred[:, 4:] > conf_thres).any(1)] # retain above threshold
# Select only suitable predictions
i = (conf > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(
pred).all(1)
i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1)
pred = pred[i]
# If none are remaining => process next image
@ -505,11 +492,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal
if multi_cls:
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
else:
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) # (xyxy, conf, cls)
else: # best class only
conf, j = pred[:, 4:].max(1)
pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls)
# Get detections sorted by decreasing confidence scores
pred = pred[(-pred[:, 4]).argsort()]
pred = pred[pred[:, 4].argsort(descending=True)]
# Batched NMS
if method == 'vision_batch':