Merge NMS update

This commit is contained in:
Glenn Jocher 2020-03-26 12:33:12 -07:00
parent eac07f9da3
commit 171b4129b5
1 changed files with 25 additions and 28 deletions

View File

@ -495,78 +495,75 @@ def build_targets(model, targets):
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False): def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
""" """
Removes detections with lower object confidence score than 'conf_thres' Performs Non-Maximum Suppression on inference results
Non-Maximum Suppression to further filter detections.
Returns detections with shape: Returns detections with shape:
(x1, y1, x2, y2, object_conf, conf, class) nx6 (x1, y1, x2, y2, conf, cls)
""" """
# NMS methods https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch'
# Box constraints # Box constraints
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
method = 'vision' method = 'merge'
nc = prediction[0].shape[1] - 5 # number of classes nc = prediction[0].shape[1] - 5 # number of classes
multi_label &= nc > 1 # multiple labels per box multi_label &= nc > 1 # multiple labels per box
output = [None] * len(prediction) output = [None] * len(prediction)
for image_i, pred in enumerate(prediction): for xi, x in enumerate(prediction): # image index, image inference
# Apply conf constraint # Apply conf constraint
pred = pred[pred[:, 4] > conf_thres] x = x[x[:, 4] > conf_thres]
# Apply width-height constraint # Apply width-height constraint
pred = pred[((pred[:, 2:4] > min_wh) & (pred[:, 2:4] < max_wh)).all(1)] x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)]
# If none remain process next image # If none remain process next image
if not pred.shape[0]: if not x.shape[0]:
continue continue
# Compute conf # Compute conf
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf x[..., 5:] *= x[..., 4:5] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2) # Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(pred[:, :4]) box = xywh2xyxy(x[:, :4])
# Detections matrix nx6 (xyxy, conf, cls) # Detections matrix nx6 (xyxy, conf, cls)
if multi_label: if multi_label:
i, j = (pred[:, 5:] > conf_thres).nonzero().t() i, j = (x[:, 5:] > conf_thres).nonzero().t()
pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
else: # best class only else: # best class only
conf, j = pred[:, 5:].max(1) conf, j = x[:, 5:].max(1)
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
# Filter by class # Filter by class
if classes: if classes:
pred = pred[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)] x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]
# Apply finite constraint # Apply finite constraint
if not torch.isfinite(pred).all(): if not torch.isfinite(x).all():
pred = pred[torch.isfinite(pred).all(1)] x = x[torch.isfinite(x).all(1)]
# If none remain process next image # If none remain process next image
if not pred.shape[0]: if not x.shape[0]:
continue continue
# Sort by confidence # Sort by confidence
# if method == 'fast_batch': # if method == 'fast_batch':
# pred = pred[pred[:, 4].argsort(descending=True)] # x = x[x[:, 4].argsort(descending=True)]
# Batched NMS # Batched NMS
c = pred[:, 5] * 0 if agnostic else pred[:, 5] # classes c = x[:, 5] * 0 if agnostic else x[:, 5] # classes
boxes, scores = pred[:, :4].clone(), pred[:, 4] boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
boxes += c.view(-1, 1) * max_wh # offset boxes by class if method == 'merge': # Merge NMS (boxes merged using weighted mean)
if method == 'vision':
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
elif method == 'merge': # Merge NMS (boxes merged using weighted mean)
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
iou = box_iou(boxes, boxes[i]).tril_() # lower triangular iou matrix iou = box_iou(boxes, boxes[i]).tril_() # lower triangular iou matrix
weights = (iou > iou_thres) * scores.view(-1, 1) weights = (iou > iou_thres) * scores.view(-1, 1)
weights /= weights.sum(0) weights /= weights.sum(0)
pred[i, :4] = torch.matmul(weights.T, pred[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) x[i, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
elif method == 'vision':
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact
iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix
i = iou.max(0)[0] < iou_thres i = iou.max(0)[0] < iou_thres
output[image_i] = pred[i] output[xi] = x[i]
def get_yolo_layers(model): def get_yolo_layers(model):