diff --git a/test.py b/test.py index 843f077b..c16790e0 100644 --- a/test.py +++ b/test.py @@ -6,7 +6,6 @@ from torch.utils.data import DataLoader from models import * from utils.datasets import * from utils.utils import * -from utils import nms def test(cfg, @@ -76,7 +75,7 @@ def test(cfg, loss += compute_loss(train_out, targets, model)[0].item() # Run NMS - output = nms.multiprocess_nms(inf_out, conf_thres=conf_thres, nms_thres=nms_thres) + output = non_max_suppression(inf_out, conf_thres=conf_thres, nms_thres=nms_thres) # Statistics per image for si, pred in enumerate(output): @@ -192,7 +191,7 @@ def test(cfg, if __name__ == '__main__': parser = argparse.ArgumentParser(prog='test.py') - parser.add_argument('--batch-size', type=int, default=4, help='size of each image batch') + parser.add_argument('--batch-size', type=int, default=16, help='size of each image batch') parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path') parser.add_argument('--data', type=str, default='data/coco.data', help='coco.data file path') parser.add_argument('--weights', type=str, default='weights/yolov3-spp.weights', help='path to weights file')