diff --git a/utils/utils.py b/utils/utils.py index 262c89c7..6c21bfea 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -369,7 +369,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): if prediction.is_cuda: unique_labels = unique_labels.cuda(prediction.device) - nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental) + nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental) for c in unique_labels: # Get the detections with class c dc = detections[detections[:, -1] == c] @@ -387,6 +387,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes dc = dc[1:][iou < nms_thres] # remove ious > threshold + # Image Total P R mAP + # 32 5000 0.633 0.579 0.568 + # 64 5000 0.619 0.579 0.568 + # 96 5000 0.652 0.622 0.613 + # 128 5000 0.651 0.625 0.617 + elif nms_style == 'AND': # requires overlap, single boxes erased while len(dc) > 1: iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes @@ -396,10 +402,19 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): elif nms_style == 'MERGE': # weighted mixture box while len(dc) > 0: - if len(dc) == 1: # Stop if we're at the last detection - det_max.append(dc[:1]) # save highest conf detection - break - iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes + iou = bbox_iou(dc[:1], dc[0:]) # iou with other boxes + i = iou > nms_thres + + weights = dc[i, 4:5] * dc[i, 5:6] + dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum() + det_max.append(dc[:1]) + dc = dc[iou < nms_thres] + + # Image Total P R mAP + # 32 5000 0.635 0.581 0.569 + # 64 5000 0.63 0.591 0.578 + # 96 5000 0.66 0.63 0.62 + # 128 5000 0.657 0.631 0.622 if len(det_max) > 0: det_max = torch.cat(det_max)