updates
This commit is contained in:
parent
cadd2f75ff
commit
0fe246f399
|
@ -1,4 +1,5 @@
|
||||||
import glob
|
import glob
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
|
@ -10,8 +11,8 @@ import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import math
|
|
||||||
|
|
||||||
from . import torch_utils # , google_utils
|
from . import torch_utils # , google_utils
|
||||||
|
|
||||||
|
@ -503,7 +504,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
|
|
||||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
||||||
# pred[:, 4] *= class_conf # improves mAP from 0.549 to 0.551
|
|
||||||
|
|
||||||
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
|
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
|
||||||
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
|
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
|
||||||
|
@ -511,8 +511,21 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
# Get detections sorted by decreasing confidence scores
|
# Get detections sorted by decreasing confidence scores
|
||||||
pred = pred[(-pred[:, 4]).argsort()]
|
pred = pred[(-pred[:, 4]).argsort()]
|
||||||
|
|
||||||
|
# Set NMS method https://github.com/ultralytics/yolov3/issues/679
|
||||||
|
# 'OR', 'AND', 'MERGE', 'VISION', 'VISION_BATCHED'
|
||||||
|
method = 'MERGE' if conf_thres <= 0.01 else 'VISION' # MERGE is highest mAP, VISION is fastest
|
||||||
|
|
||||||
|
# Batched NMS
|
||||||
|
if method == 'VISION_BATCHED':
|
||||||
|
i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4],
|
||||||
|
scores=pred[:, 4],
|
||||||
|
idxs=pred[:, 6],
|
||||||
|
iou_threshold=nms_thres)
|
||||||
|
output[image_i] = pred[i]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Non-maximum suppression
|
||||||
det_max = []
|
det_max = []
|
||||||
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
|
||||||
for c in pred[:, -1].unique():
|
for c in pred[:, -1].unique():
|
||||||
dc = pred[pred[:, -1] == c] # select class c
|
dc = pred[pred[:, -1] == c] # select class c
|
||||||
n = len(dc)
|
n = len(dc)
|
||||||
|
@ -520,10 +533,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
det_max.append(dc) # No NMS required if only 1 prediction
|
det_max.append(dc) # No NMS required if only 1 prediction
|
||||||
continue
|
continue
|
||||||
elif n > 500:
|
elif n > 500:
|
||||||
dc = dc[:500] # limit to first 100 boxes: https://github.com/ultralytics/yolov3/issues/117
|
dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117
|
||||||
|
|
||||||
# Non-maximum suppression
|
if method == 'VISION':
|
||||||
if nms_style == 'OR': # default
|
i = torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres)
|
||||||
|
det_max.append(dc[i])
|
||||||
|
|
||||||
|
elif method == 'OR': # default
|
||||||
# METHOD1
|
# METHOD1
|
||||||
# ind = list(range(len(dc)))
|
# ind = list(range(len(dc)))
|
||||||
# while len(ind):
|
# while len(ind):
|
||||||
|
@ -540,14 +556,14 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||||
|
|
||||||
elif nms_style == 'AND': # requires overlap, single boxes erased
|
elif method == 'AND': # requires overlap, single boxes erased
|
||||||
while len(dc) > 1:
|
while len(dc) > 1:
|
||||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||||
if iou.max() > 0.5:
|
if iou.max() > 0.5:
|
||||||
det_max.append(dc[:1])
|
det_max.append(dc[:1])
|
||||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||||
|
|
||||||
elif nms_style == 'MERGE': # weighted mixture box
|
elif method == 'MERGE': # weighted mixture box
|
||||||
while len(dc):
|
while len(dc):
|
||||||
if len(dc) == 1:
|
if len(dc) == 1:
|
||||||
det_max.append(dc)
|
det_max.append(dc)
|
||||||
|
@ -558,7 +574,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
det_max.append(dc[:1])
|
det_max.append(dc[:1])
|
||||||
dc = dc[i == 0]
|
dc = dc[i == 0]
|
||||||
|
|
||||||
elif nms_style == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503
|
elif method == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503
|
||||||
sigma = 0.5 # soft-nms sigma parameter
|
sigma = 0.5 # soft-nms sigma parameter
|
||||||
while len(dc):
|
while len(dc):
|
||||||
if len(dc) == 1:
|
if len(dc) == 1:
|
||||||
|
|
Loading…
Reference in New Issue