Merge NMS update
This commit is contained in:
parent
eac07f9da3
commit
171b4129b5
|
@ -495,78 +495,75 @@ def build_targets(model, targets):
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False,
                        method='merge'):
    """
    Performs Non-Maximum Suppression on inference results.

    Arguments:
        prediction: per-image inference results. Each image is a 2-D tensor whose rows are read as
            (center x, center y, width, height, obj_conf, cls_conf_0, ..., cls_conf_{nc-1}).
        conf_thres: rows whose obj_conf is not above this value are dropped; in multi-label mode it is
            also the threshold applied to obj_conf * cls_conf for each class.
        iou_thres: IoU threshold used for suppression (and for selecting boxes to merge).
        multi_label: if True (and nc > 1) emit one detection per (box, class) pair above conf_thres,
            otherwise keep only the best class per box.
        classes: optional list of class indices; when given and non-empty, other classes are discarded.
        agnostic: if True, boxes of different classes suppress each other.
        method: 'merge' (torchvision NMS, then each kept box is replaced by the score-weighted mean of the
            boxes overlapping it), 'vision' (plain torchvision NMS) or 'fast' (matrix FastNMS from
            https://github.com/dbolya/yolact).

    Returns detections with shape:
        nx6 (x1, y1, x2, y2, conf, cls)
    as a list with one entry per image; the entry is None for images with no remaining detections.

    Raises:
        ValueError: if method is not one of the supported names.
    """
    if method not in ('merge', 'vision', 'fast'):
        raise ValueError("Unknown NMS method '%s', expected 'merge', 'vision' or 'fast'" % method)

    if len(prediction) == 0:  # empty batch, nothing to index below
        return []

    # Box constraints
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height

    nc = prediction[0].shape[1] - 5  # number of classes
    multi_label &= nc > 1  # multiple labels per box
    output = [None] * len(prediction)
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply conf constraint (boolean indexing copies, so the caller's tensor is not modified below)
        x = x[x[:, 4] > conf_thres]

        # Apply width-height constraint
        x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)]

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[..., 5:] *= x[..., 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            rows, j = (x[:, 5:] > conf_thres).nonzero().t()
            x = torch.cat((box[rows], x[rows, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1)
            x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)

        # Filter by class
        if classes:
            x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]

        # Apply finite constraint
        if not torch.isfinite(x).all():
            x = x[torch.isfinite(x).all(1)]

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Batched NMS: shifting each class by max_wh keeps boxes of different classes from overlapping
        c = x[:, 5] * 0 if agnostic else x[:, 5]  # classes
        boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4]  # boxes (offset by class), scores
        if method == 'merge':  # Merge NMS (boxes merged using weighted mean)
            keep = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
            # Full (kept, n) overlap matrix. No triangular mask: the rows are not sorted by score, so a
            # mask would drop legitimate overlaps (even a box's own) and could leave an all-zero weight row.
            overlap = box_iou(boxes[keep], boxes) > iou_thres
            weights = overlap * scores[None]  # (kept, n) score weights of the overlapping boxes
            total = weights.sum(1, keepdim=True)
            ok = total.squeeze(1) > 0  # leave a box untouched rather than divide by zero
            x[keep[ok], :4] = torch.mm(weights[ok], x[:, :4]) / total[ok]  # merged(kept,4) = weights * boxes
        elif method == 'vision':
            keep = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        else:  # 'fast': FastNMS from https://github.com/dbolya/yolact
            # The upper-triangular trick compares each box only with lower-index boxes, so they must be
            # ordered by descending score first.
            order = scores.argsort(descending=True)
            x, boxes = x[order], boxes[order]
            iou = box_iou(boxes, boxes).triu_(diagonal=1)  # upper triangular iou matrix
            keep = iou.max(0)[0] < iou_thres

        output[xi] = x[keep]

    return output
|
||||||
|
|
||||||
|
|
||||||
def get_yolo_layers(model):
|
def get_yolo_layers(model):
|
||||||
|
|
Loading…
Reference in New Issue