NMS and test batch_size updates
This commit is contained in:
parent
c6b59a0e8a
commit
eb151a881e
6
train.py
6
train.py
|
@ -180,12 +180,12 @@ def train():
|
||||||
collate_fn=dataset.collate_fn)
|
collate_fn=dataset.collate_fn)
|
||||||
|
|
||||||
# Testloader
|
# Testloader
|
||||||
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size * 2,
|
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size,
|
||||||
hyp=hyp,
|
hyp=hyp,
|
||||||
rect=True,
|
rect=True,
|
||||||
cache_images=opt.cache_images,
|
cache_images=opt.cache_images,
|
||||||
single_cls=opt.single_cls),
|
single_cls=opt.single_cls),
|
||||||
batch_size=batch_size * 2,
|
batch_size=batch_size,
|
||||||
num_workers=nw,
|
num_workers=nw,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=dataset.collate_fn)
|
collate_fn=dataset.collate_fn)
|
||||||
|
@ -311,7 +311,7 @@ def train():
|
||||||
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
|
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
|
||||||
results, maps = test.test(cfg,
|
results, maps = test.test(cfg,
|
||||||
data,
|
data,
|
||||||
batch_size=batch_size * 2,
|
batch_size=batch_size,
|
||||||
img_size=img_size_test,
|
img_size=img_size_test,
|
||||||
model=ema.ema,
|
model=ema.ema,
|
||||||
conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed
|
conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed
|
||||||
|
|
|
@ -543,7 +543,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
x = x[torch.isfinite(x).all(1)]
|
x = x[torch.isfinite(x).all(1)]
|
||||||
|
|
||||||
# If none remain process next image
|
# If none remain process next image
|
||||||
if not x.shape[0]:
|
n = x.shape[0] # number of boxes
|
||||||
|
if not n:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Sort by confidence
|
# Sort by confidence
|
||||||
|
@ -555,6 +556,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
||||||
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||||
|
if n < 1000: # update boxes
|
||||||
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
||||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||||
weights /= weights.sum(0)
|
weights /= weights.sum(0)
|
||||||
|
|
Loading…
Reference in New Issue