updates
This commit is contained in:
parent
e4d62de5bc
commit
a80b2d1611
|
@ -369,7 +369,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
if prediction.is_cuda:
|
||||
unique_labels = unique_labels.cuda(prediction.device)
|
||||
|
||||
nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||
for c in unique_labels:
|
||||
# Get the detections with class c
|
||||
dc = detections[detections[:, -1] == c]
|
||||
|
@ -387,6 +387,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
|
||||
# Image Total P R mAP
|
||||
# 32 5000 0.633 0.579 0.568
|
||||
# 64 5000 0.619 0.579 0.568
|
||||
# 96 5000 0.652 0.622 0.613
|
||||
# 128 5000 0.651 0.625 0.617
|
||||
|
||||
elif nms_style == 'AND': # requires overlap, single boxes erased
|
||||
while len(dc) > 1:
|
||||
iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes
|
||||
|
@ -396,10 +402,19 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
|
||||
elif nms_style == 'MERGE': # weighted mixture box
|
||||
while len(dc) > 0:
|
||||
if len(dc) == 1: # Stop if we're at the last detection
|
||||
det_max.append(dc[:1]) # save highest conf detection
|
||||
break
|
||||
iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes
|
||||
iou = bbox_iou(dc[:1], dc[0:]) # iou with other boxes
|
||||
i = iou > nms_thres
|
||||
|
||||
weights = dc[i, 4:5] * dc[i, 5:6]
|
||||
dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
|
||||
det_max.append(dc[:1])
|
||||
dc = dc[iou < nms_thres]
|
||||
|
||||
# Image Total P R mAP
|
||||
# 32 5000 0.635 0.581 0.569
|
||||
# 64 5000 0.63 0.591 0.578
|
||||
# 96 5000 0.66 0.63 0.62
|
||||
# 128 5000 0.657 0.631 0.622
|
||||
|
||||
if len(det_max) > 0:
|
||||
det_max = torch.cat(det_max)
|
||||
|
|
Loading…
Reference in New Issue