This commit is contained in:
Glenn Jocher 2019-02-18 19:21:21 +01:00
parent adea337545
commit 77ce2cd43f
1 changed files with 4 additions and 9 deletions

View File

@ -388,10 +388,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
dc = dc[1:][iou < nms_thres] # remove ious > threshold dc = dc[1:][iou < nms_thres] # remove ious > threshold
# Image Total P R mAP # Image Total P R mAP
# 32 5000 0.633 0.579 0.568
# 64 5000 0.619 0.579 0.568
# 96 5000 0.652 0.622 0.613
# 128 5000 0.651 0.625 0.617
# 5000 5000 0.627 0.593 0.584 # 5000 5000 0.627 0.593 0.584
elif nms_style == 'AND': # requires overlap, single boxes erased elif nms_style == 'AND': # requires overlap, single boxes erased
@ -406,16 +402,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
iou = bbox_iou(dc[:1], dc[0:]) # iou with other boxes iou = bbox_iou(dc[:1], dc[0:]) # iou with other boxes
i = iou > nms_thres i = iou > nms_thres
weights = (dc[i, 4:5] * dc[i, 5:6]) ** 0.5 weights = (dc[i, 4:5] * dc[i, 5:6]) ** 2
dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum() dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
det_max.append(dc[:1]) det_max.append(dc[:1])
dc = dc[iou < nms_thres] dc = dc[iou < nms_thres]
# Image Total P R mAP # Image Total P R mAP
# 32 5000 0.635 0.581 0.569 # 4964 5000 0.632 0.597 0.588 # normal
# 64 5000 0.63 0.591 0.578 # 4964 5000 0.632 0.597 0.588 # squared
# 96 5000 0.66 0.63 0.62 # 4964 5000 0.631 0.597 0.588 # sqrt
# 128 5000 0.657 0.631 0.622
if len(det_max) > 0: if len(det_max) > 0:
det_max = torch.cat(det_max) det_max = torch.cat(det_max)