This commit is contained in:
Glenn Jocher 2019-02-18 19:13:40 +01:00
parent e4d62de5bc
commit a80b2d1611
1 changed files with 20 additions and 5 deletions

View File

@ -369,7 +369,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
if prediction.is_cuda: if prediction.is_cuda:
unique_labels = unique_labels.cuda(prediction.device) unique_labels = unique_labels.cuda(prediction.device)
nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental) nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
for c in unique_labels: for c in unique_labels:
# Get the detections with class c # Get the detections with class c
dc = detections[detections[:, -1] == c] dc = detections[detections[:, -1] == c]
@ -387,6 +387,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes
dc = dc[1:][iou < nms_thres] # remove ious > threshold dc = dc[1:][iou < nms_thres] # remove ious > threshold
# Image Total P R mAP
# 32 5000 0.633 0.579 0.568
# 64 5000 0.619 0.579 0.568
# 96 5000 0.652 0.622 0.613
# 128 5000 0.651 0.625 0.617
elif nms_style == 'AND': # requires overlap, single boxes erased elif nms_style == 'AND': # requires overlap, single boxes erased
while len(dc) > 1: while len(dc) > 1:
iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes
@ -396,10 +402,19 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
elif nms_style == 'MERGE': # weighted mixture box elif nms_style == 'MERGE': # weighted mixture box
while len(dc) > 0: while len(dc) > 0:
if len(dc) == 1: # Stop if we're at the last detection iou = bbox_iou(dc[:1], dc[0:]) # iou with other boxes
det_max.append(dc[:1]) # save highest conf detection i = iou > nms_thres
break
iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes weights = dc[i, 4:5] * dc[i, 5:6]
dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
det_max.append(dc[:1])
dc = dc[iou < nms_thres]
# Image Total P R mAP
# 32 5000 0.635 0.581 0.569
# 64 5000 0.63 0.591 0.578
# 96 5000 0.66 0.63 0.62
# 128 5000 0.657 0.631 0.622
if len(det_max) > 0: if len(det_max) > 0:
det_max = torch.cat(det_max) det_max = torch.cat(det_max)