NMS and test batch_size updates
This commit is contained in:
parent
c6b59a0e8a
commit
eb151a881e
6
train.py
6
train.py
|
@ -180,12 +180,12 @@ def train():
|
|||
collate_fn=dataset.collate_fn)
|
||||
|
||||
# Testloader
|
||||
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size * 2,
|
||||
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size,
|
||||
hyp=hyp,
|
||||
rect=True,
|
||||
cache_images=opt.cache_images,
|
||||
single_cls=opt.single_cls),
|
||||
batch_size=batch_size * 2,
|
||||
batch_size=batch_size,
|
||||
num_workers=nw,
|
||||
pin_memory=True,
|
||||
collate_fn=dataset.collate_fn)
|
||||
|
@ -311,7 +311,7 @@ def train():
|
|||
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
|
||||
results, maps = test.test(cfg,
|
||||
data,
|
||||
batch_size=batch_size * 2,
|
||||
batch_size=batch_size,
|
||||
img_size=img_size_test,
|
||||
model=ema.ema,
|
||||
conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed
|
||||
|
|
|
@ -543,7 +543,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
|||
x = x[torch.isfinite(x).all(1)]
|
||||
|
||||
# If none remain process next image
|
||||
if not x.shape[0]:
|
||||
n = x.shape[0] # number of boxes
|
||||
if not n:
|
||||
continue
|
||||
|
||||
# Sort by confidence
|
||||
|
@ -555,10 +556,11 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
|||
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
||||
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||
weights /= weights.sum(0)
|
||||
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||
if n < 1000: # update boxes
|
||||
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||
weights /= weights.sum(0)
|
||||
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||
elif method == 'vision':
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact
|
||||
|
|
Loading…
Reference in New Issue