updates
This commit is contained in:
parent
9048d96c71
commit
ce9a2cb9d2
|
@ -474,24 +474,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal
|
||||||
|
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
for image_i, pred in enumerate(prediction):
|
for image_i, pred in enumerate(prediction):
|
||||||
# Duplicate ambiguous
|
# Remove rows
|
||||||
# b = pred[pred[:, 5:].sum(1) > 1.1]
|
pred = pred[(pred[:, 4:] > conf_thres).any(1)] # retain above threshold
|
||||||
# if len(b):
|
|
||||||
# b[range(len(b)), 5 + b[:, 5:].argmax(1)] = 0
|
|
||||||
# pred = torch.cat((pred, b), 0)
|
|
||||||
|
|
||||||
# Multiply conf by class conf to get combined confidence
|
|
||||||
conf, cls = pred[:, 4:].max(1)
|
|
||||||
|
|
||||||
# # Merge classes (optional)
|
|
||||||
# cls[(cls.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2
|
|
||||||
#
|
|
||||||
# # Remove classes (optional)
|
|
||||||
# pred[cls != 2, 4] = 0.0
|
|
||||||
|
|
||||||
# Select only suitable predictions
|
# Select only suitable predictions
|
||||||
i = (conf > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(
|
i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1)
|
||||||
pred).all(1)
|
|
||||||
pred = pred[i]
|
pred = pred[i]
|
||||||
|
|
||||||
# If none are remaining => process next image
|
# If none are remaining => process next image
|
||||||
|
@ -505,11 +492,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal
|
||||||
if multi_cls:
|
if multi_cls:
|
||||||
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
|
||||||
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||||
else:
|
else: # best class only
|
||||||
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) # (xyxy, conf, cls)
|
conf, j = pred[:, 4:].max(1)
|
||||||
|
pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls)
|
||||||
|
|
||||||
# Get detections sorted by decreasing confidence scores
|
# Get detections sorted by decreasing confidence scores
|
||||||
pred = pred[(-pred[:, 4]).argsort()]
|
pred = pred[pred[:, 4].argsort(descending=True)]
|
||||||
|
|
||||||
# Batched NMS
|
# Batched NMS
|
||||||
if method == 'vision_batch':
|
if method == 'vision_batch':
|
||||||
|
|
Loading…
Reference in New Issue