updates
This commit is contained in:
parent
6e2cf074a1
commit
e4d62de5bc
|
@ -146,7 +146,7 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
def forward(self, p, targets=None, var=None):
|
||||
bs = 1 if ONNX_EXPORT else p.shape[0] # batch size
|
||||
nG = self.nG # number of grid points
|
||||
nG = self.nG if ONNX_EXPORT else p.shape[-1] # number of grid points
|
||||
|
||||
if p.is_cuda and not self.weights.is_cuda:
|
||||
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
||||
|
|
|
@ -369,44 +369,40 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
if prediction.is_cuda:
|
||||
unique_labels = unique_labels.cuda(prediction.device)
|
||||
|
||||
nms_style = 'OR' # 'AND', 'OR' (classical), 'MERGE' (experimental)
|
||||
nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||
for c in unique_labels:
|
||||
# Get the detections with the particular class
|
||||
det_class = detections[detections[:, -1] == c]
|
||||
# Sort the detections by maximum objectness confidence
|
||||
_, conf_sort_index = torch.sort(det_class[:, 4], descending=True)
|
||||
det_class = det_class[conf_sort_index]
|
||||
# Perform non-maximum suppression
|
||||
# Get the detections with class c
|
||||
dc = detections[detections[:, -1] == c]
|
||||
# Sort the detections by maximum object confidence
|
||||
_, conf_sort_index = torch.sort(dc[:, 4], descending=True)
|
||||
dc = dc[conf_sort_index]
|
||||
|
||||
# Non-maximum suppression
|
||||
det_max = []
|
||||
|
||||
if nms_style == 'OR': # Classical NMS
|
||||
while det_class.shape[0]:
|
||||
# Get detection with highest confidence and save as max detection
|
||||
det_max.append(det_class[0].unsqueeze(0))
|
||||
# Stop if we're at the last detection
|
||||
if len(det_class) == 1:
|
||||
if nms_style == 'OR': # default
|
||||
while dc.shape[0]:
|
||||
det_max.append(dc[:1]) # save highest conf detection
|
||||
if len(dc) == 1: # Stop if we're at the last detection
|
||||
break
|
||||
# Get the IOUs for all boxes with lower confidence
|
||||
ious = bbox_iou(det_max[-1], det_class[1:])
|
||||
iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
|
||||
# Remove detections with IoU >= NMS threshold
|
||||
det_class = det_class[1:][ious < nms_thres]
|
||||
elif nms_style == 'AND': # requires overlap, single boxes erased
|
||||
while len(dc) > 1:
|
||||
iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes
|
||||
if iou.max() > 0.5:
|
||||
det_max.append(dc[:1])
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
|
||||
elif nms_style == 'AND': # 'AND'-style NMS: >=2 boxes must share commonality to pass, single boxes erased
|
||||
while det_class.shape[0]:
|
||||
if len(det_class) == 1:
|
||||
elif nms_style == 'MERGE': # weighted mixture box
|
||||
while len(dc) > 0:
|
||||
if len(dc) == 1: # Stop if we're at the last detection
|
||||
det_max.append(dc[:1]) # save highest conf detection
|
||||
break
|
||||
|
||||
ious = bbox_iou(det_class[:1], det_class[1:])
|
||||
|
||||
if ious.max() > 0.5:
|
||||
det_max.append(det_class[0].unsqueeze(0))
|
||||
|
||||
# Remove detections with IoU >= NMS threshold
|
||||
det_class = det_class[1:][ious < nms_thres]
|
||||
iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes
|
||||
|
||||
if len(det_max) > 0:
|
||||
det_max = torch.cat(det_max).data
|
||||
det_max = torch.cat(det_max)
|
||||
# Add max detections to outputs
|
||||
output[image_i] = det_max if output[image_i] is None else torch.cat((output[image_i], det_max))
|
||||
|
||||
|
|
Loading…
Reference in New Issue