This commit is contained in:
Glenn Jocher 2019-12-02 18:22:21 -08:00
parent cadd2f75ff
commit 0fe246f399
1 changed files with 25 additions and 9 deletions

View File

@ -1,4 +1,5 @@
import glob import glob
import math
import os import os
import random import random
import shutil import shutil
@ -10,8 +11,8 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision
from tqdm import tqdm from tqdm import tqdm
import math
from . import torch_utils # , google_utils from . import torch_utils # , google_utils
@ -503,7 +504,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
# Box (center x, center y, width, height) to (x1, y1, x2, y2) # Box (center x, center y, width, height) to (x1, y1, x2, y2)
pred[:, :4] = xywh2xyxy(pred[:, :4]) pred[:, :4] = xywh2xyxy(pred[:, :4])
# pred[:, 4] *= class_conf # improves mAP from 0.549 to 0.551
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred) # Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1) pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
@ -511,8 +511,21 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
# Get detections sorted by decreasing confidence scores # Get detections sorted by decreasing confidence scores
pred = pred[(-pred[:, 4]).argsort()] pred = pred[(-pred[:, 4]).argsort()]
# Set NMS method https://github.com/ultralytics/yolov3/issues/679
# 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED'
method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest
# Batched NMS
if method == 'VISION_BATCHED':
i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4],
scores=pred[:, 4],
idxs=pred[:, 6],
iou_threshold=nms_thres)
output[image_i] = pred[i]
continue
# Non-maximum suppression
det_max = [] det_max = []
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
for c in pred[:, -1].unique(): for c in pred[:, -1].unique():
dc = pred[pred[:, -1] == c] # select class c dc = pred[pred[:, -1] == c] # select class c
n = len(dc) n = len(dc)
@ -520,10 +533,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
det_max.append(dc) # No NMS required if only 1 prediction det_max.append(dc) # No NMS required if only 1 prediction
continue continue
elif n > 500: elif n > 500:
dc = dc[:500] # limit to first 100 boxes: https://github.com/ultralytics/yolov3/issues/117 dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117
# Non-maximum suppression if method == 'VISION':
if nms_style == 'OR': # default i = torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres)
det_max.append(dc[i])
elif method == 'OR': # default
# METHOD1 # METHOD1
# ind = list(range(len(dc))) # ind = list(range(len(dc)))
# while len(ind): # while len(ind):
@ -540,14 +556,14 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
dc = dc[1:][iou < nms_thres] # remove ious > threshold dc = dc[1:][iou < nms_thres] # remove ious > threshold
elif nms_style == 'AND': # requires overlap, single boxes erased elif method == 'AND': # requires overlap, single boxes erased
while len(dc) > 1: while len(dc) > 1:
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
if iou.max() > 0.5: if iou.max() > 0.5:
det_max.append(dc[:1]) det_max.append(dc[:1])
dc = dc[1:][iou < nms_thres] # remove ious > threshold dc = dc[1:][iou < nms_thres] # remove ious > threshold
elif nms_style == 'MERGE': # weighted mixture box elif method == 'MERGE': # weighted mixture box
while len(dc): while len(dc):
if len(dc) == 1: if len(dc) == 1:
det_max.append(dc) det_max.append(dc)
@ -558,7 +574,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
det_max.append(dc[:1]) det_max.append(dc[:1])
dc = dc[i == 0] dc = dc[i == 0]
elif nms_style == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503 elif method == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503
sigma = 0.5 # soft-nms sigma parameter sigma = 0.5 # soft-nms sigma parameter
while len(dc): while len(dc):
if len(dc) == 1: if len(dc) == 1: