diff --git a/train.py b/train.py index a1f2fe1a..33adb52e 100644 --- a/train.py +++ b/train.py @@ -180,12 +180,12 @@ def train(): collate_fn=dataset.collate_fn) # Testloader - testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size * 2, + testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, img_size_test, batch_size, hyp=hyp, rect=True, cache_images=opt.cache_images, single_cls=opt.single_cls), - batch_size=batch_size * 2, + batch_size=batch_size, num_workers=nw, pin_memory=True, collate_fn=dataset.collate_fn) @@ -311,7 +311,7 @@ def train(): is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80 results, maps = test.test(cfg, data, - batch_size=batch_size * 2, + batch_size=batch_size, img_size=img_size_test, model=ema.ema, conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed diff --git a/utils/utils.py b/utils/utils.py index 9eee669c..80b25d94 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -543,7 +543,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T x = x[torch.isfinite(x).all(1)] # If none remain process next image - if not x.shape[0]: + n = x.shape[0] # number of boxes + if not n: continue # Sort by confidence @@ -555,10 +556,11 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores if method == 'merge': # Merge NMS (boxes merged using weighted mean) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) - iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix - weights = (iou > iou_thres) * scores.view(-1, 1) - weights /= weights.sum(0) - x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) + if n < 1000: # update boxes + iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix + weights = (iou > iou_thres) * scores.view(-1, 1) + weights /= weights.sum(0) + x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) elif method == 'vision': i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) elif method == 'fast': # FastNMS from https://github.com/dbolya/yolact