NMS and test batch_size updates

This commit is contained in:
Glenn Jocher 2020-03-29 20:41:32 -07:00
parent c6b59a0e8a
commit eb151a881e
2 changed files with 10 additions and 8 deletions

View File

@ -180,12 +180,12 @@ def train():
collate_fn=dataset.collate_fn)
# Testloader
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size * 2,
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size,
hyp=hyp,
rect=True,
cache_images=opt.cache_images,
single_cls=opt.single_cls),
batch_size=batch_size * 2,
batch_size=batch_size,
num_workers=nw,
pin_memory=True,
collate_fn=dataset.collate_fn)
@ -311,7 +311,7 @@ def train():
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
results, maps = test.test(cfg,
data,
batch_size=batch_size * 2,
batch_size=batch_size,
img_size=img_size_test,
model=ema.ema,
conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed

View File

@ -543,7 +543,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
x = x[torch.isfinite(x).all(1)]
# If none remain process next image
if not x.shape[0]:
n = x.shape[0] # number of boxes
if not n:
continue
# Sort by confidence
@ -555,10 +556,11 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
weights = (iou > iou_thres) * scores.view(-1, 1)
weights /= weights.sum(0)
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
if n < 1000: # update boxes
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
weights = (iou > iou_thres) * scores.view(-1, 1)
weights /= weights.sum(0)
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
elif method == 'vision':
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact