From e4d62de5bc12d1e411adbfe4b76f15d157d77c65 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 18 Feb 2019 18:32:31 +0100 Subject: [PATCH] updates --- models.py | 2 +- utils/utils.py | 56 +++++++++++++++++++++++--------------------------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/models.py b/models.py index c8b0ed1a..e1fe21cb 100755 --- a/models.py +++ b/models.py @@ -146,7 +146,7 @@ class YOLOLayer(nn.Module): def forward(self, p, targets=None, var=None): bs = 1 if ONNX_EXPORT else p.shape[0] # batch size - nG = self.nG # number of grid points + nG = self.nG if ONNX_EXPORT else p.shape[-1] # number of grid points if p.is_cuda and not self.weights.is_cuda: self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() diff --git a/utils/utils.py b/utils/utils.py index b5666268..262c89c7 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -369,44 +369,40 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): if prediction.is_cuda: unique_labels = unique_labels.cuda(prediction.device) - nms_style = 'OR' # 'AND', 'OR' (classical), 'MERGE' (experimental) + nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental) for c in unique_labels: - # Get the detections with the particular class - det_class = detections[detections[:, -1] == c] - # Sort the detections by maximum objectness confidence - _, conf_sort_index = torch.sort(det_class[:, 4], descending=True) - det_class = det_class[conf_sort_index] - # Perform non-maximum suppression + # Get the detections with class c + dc = detections[detections[:, -1] == c] + # Sort the detections by maximum object confidence + _, conf_sort_index = torch.sort(dc[:, 4], descending=True) + dc = dc[conf_sort_index] + + # Non-maximum suppression det_max = [] - - if nms_style == 'OR': # Classical NMS - while det_class.shape[0]: - # Get detection with highest confidence and save as max detection - det_max.append(det_class[0].unsqueeze(0)) - # Stop if we're at the last detection - if len(det_class) == 1: + if nms_style == 'OR': # default + while dc.shape[0]: + det_max.append(dc[:1]) # save highest conf detection + if len(dc) == 1: # Stop if we're at the last detection break - # Get the IOUs for all boxes with lower confidence - ious = bbox_iou(det_max[-1], det_class[1:]) + iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes + dc = dc[1:][iou < nms_thres] # remove ious > threshold - # Remove detections with IoU >= NMS threshold - det_class = det_class[1:][ious < nms_thres] + elif nms_style == 'AND': # requires overlap, single boxes erased + while len(dc) > 1: + iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes + if iou.max() > 0.5: + det_max.append(dc[:1]) + dc = dc[1:][iou < nms_thres] # remove ious > threshold - elif nms_style == 'AND': # 'AND'-style NMS: >=2 boxes must share commonality to pass, single boxes erased - while det_class.shape[0]: - if len(det_class) == 1: + elif nms_style == 'MERGE': # weighted mixture box + while len(dc) > 0: + if len(dc) == 1: # Stop if we're at the last detection + det_max.append(dc[:1]) # save highest conf detection break - - ious = bbox_iou(det_class[:1], det_class[1:]) - - if ious.max() > 0.5: - det_max.append(det_class[0].unsqueeze(0)) - - # Remove detections with IoU >= NMS threshold - det_class = det_class[1:][ious < nms_thres] + iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes if len(det_max) > 0: - det_max = torch.cat(det_max).data + det_max = torch.cat(det_max) # Add max detections to outputs output[image_i] = det_max if output[image_i] is None else torch.cat((output[image_i], det_max))