Merge NMS update
This commit is contained in:
parent
eac07f9da3
commit
171b4129b5
|
@ -495,78 +495,75 @@ def build_targets(model, targets):
|
|||
|
||||
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
|
||||
"""
|
||||
Removes detections with lower object confidence score than 'conf_thres'
|
||||
Non-Maximum Suppression to further filter detections.
|
||||
Performs Non-Maximum Suppression on inference results
|
||||
Returns detections with shape:
|
||||
(x1, y1, x2, y2, object_conf, conf, class)
|
||||
nx6 (x1, y1, x2, y2, conf, cls)
|
||||
"""
|
||||
# NMS methods https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch'
|
||||
|
||||
# Box constraints
|
||||
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
||||
|
||||
method = 'vision'
|
||||
method = 'merge'
|
||||
nc = prediction[0].shape[1] - 5 # number of classes
|
||||
multi_label &= nc > 1 # multiple labels per box
|
||||
output = [None] * len(prediction)
|
||||
for image_i, pred in enumerate(prediction):
|
||||
for xi, x in enumerate(prediction): # image index, image inference
|
||||
# Apply conf constraint
|
||||
pred = pred[pred[:, 4] > conf_thres]
|
||||
x = x[x[:, 4] > conf_thres]
|
||||
|
||||
# Apply width-height constraint
|
||||
pred = pred[((pred[:, 2:4] > min_wh) & (pred[:, 2:4] < max_wh)).all(1)]
|
||||
x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)]
|
||||
|
||||
# If none remain process next image
|
||||
if not pred.shape[0]:
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
|
||||
# Compute conf
|
||||
pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf
|
||||
x[..., 5:] *= x[..., 4:5] # conf = obj_conf * cls_conf
|
||||
|
||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||
box = xywh2xyxy(pred[:, :4])
|
||||
box = xywh2xyxy(x[:, :4])
|
||||
|
||||
# Detections matrix nx6 (xyxy, conf, cls)
|
||||
if multi_label:
|
||||
i, j = (pred[:, 5:] > conf_thres).nonzero().t()
|
||||
pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
i, j = (x[:, 5:] > conf_thres).nonzero().t()
|
||||
x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
else: # best class only
|
||||
conf, j = pred[:, 5:].max(1)
|
||||
pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
conf, j = x[:, 5:].max(1)
|
||||
x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)
|
||||
|
||||
# Filter by class
|
||||
if classes:
|
||||
pred = pred[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]
|
||||
x = x[(j.view(-1, 1) == torch.tensor(classes, device=j.device)).any(1)]
|
||||
|
||||
# Apply finite constraint
|
||||
if not torch.isfinite(pred).all():
|
||||
pred = pred[torch.isfinite(pred).all(1)]
|
||||
if not torch.isfinite(x).all():
|
||||
x = x[torch.isfinite(x).all(1)]
|
||||
|
||||
# If none remain process next image
|
||||
if not pred.shape[0]:
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
|
||||
# Sort by confidence
|
||||
# if method == 'fast_batch':
|
||||
# pred = pred[pred[:, 4].argsort(descending=True)]
|
||||
# x = x[x[:, 4].argsort(descending=True)]
|
||||
|
||||
# Batched NMS
|
||||
c = pred[:, 5] * 0 if agnostic else pred[:, 5] # classes
|
||||
boxes, scores = pred[:, :4].clone(), pred[:, 4]
|
||||
boxes += c.view(-1, 1) * max_wh # offset boxes by class
|
||||
if method == 'vision':
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
elif method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||
c = x[:, 5] * 0 if agnostic else x[:, 5] # classes
|
||||
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
||||
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
iou = box_iou(boxes, boxes[i]).tril_() # lower triangular iou matrix
|
||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||
weights /= weights.sum(0)
|
||||
pred[i, :4] = torch.matmul(weights.T, pred[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||
x[i, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||
elif method == 'vision':
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact
|
||||
iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix
|
||||
i = iou.max(0)[0] < iou_thres
|
||||
|
||||
output[image_i] = pred[i]
|
||||
output[xi] = x[i]
|
||||
|
||||
|
||||
def get_yolo_layers(model):
|
||||
|
|
Loading…
Reference in New Issue