updates
This commit is contained in:
parent
9048d96c71
commit
ce9a2cb9d2
|
@ -474,24 +474,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal
|
|||
|
||||
output = [None] * len(prediction)
|
||||
for image_i, pred in enumerate(prediction):
|
||||
# Duplicate ambiguous
|
||||
# b = pred[pred[:, 5:].sum(1) > 1.1]
|
||||
# if len(b):
|
||||
# b[range(len(b)), 5 + b[:, 5:].argmax(1)] = 0
|
||||
# pred = torch.cat((pred, b), 0)
|
||||
|
||||
# Multiply conf by class conf to get combined confidence
|
||||
conf, cls = pred[:, 4:].max(1)
|
||||
|
||||
# # Merge classes (optional)
|
||||
# cls[(cls.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
|
||||
#
|
||||
# # Remove classes (optional)
|
||||
# pred[cls != 2, 4] = 0.0
|
||||
# Remove rows
|
||||
pred = pred[(pred[:, 4:] > conf_thres).any(1)] # retain above threshold
|
||||
|
||||
# Select only suitable predictions
|
||||
i = (conf > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(
|
||||
pred).all(1)
|
||||
i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1)
|
||||
pred = pred[i]
|
||||
|
||||
# If none are remaining => process next image
|
||||
|
@ -505,11 +492,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal
|
|||
if multi_cls:
|
||||
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
||||
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
else:
|
||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) # (xyxy, conf, cls)
|
||||
else: # best class only
|
||||
conf, j = pred[:, 4:].max(1)
|
||||
pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls)
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
pred = pred[(-pred[:, 4]).argsort()]
|
||||
pred = pred[pred[:, 4].argsort(descending=True)]
|
||||
|
||||
# Batched NMS
|
||||
if method == 'vision_batch':
|
||||
|
|
Loading…
Reference in New Issue