diff --git a/test.py b/test.py index 9c2d9de4..7bcef824 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,6 @@ from utils.datasets import * from utils.utils import * parser = argparse.ArgumentParser() -parser.add_argument('-epochs', type=int, default=200, help='number of epochs') parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch') parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file') parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file') @@ -32,11 +31,10 @@ num_classes = int(data_config['classes']) model = Darknet(opt.cfg, opt.img_size) # Load weights -weights_path = 'checkpoints/yolov3.weights' -if weights_path.endswith('.weights'): # darknet format - load_weights(model, weights_path) -elif weights_path.endswith('.pt'): # pytorch format - checkpoint = torch.load(weights_path, map_location='cpu') +if opt.weights_path('.weights'): # darknet format + load_weights(model, opt.weights_path) +elif opt.weights_path.endswith('.pt'): # pytorch format + checkpoint = torch.load(opt.weights_path, map_location='cpu') model.load_state_dict(checkpoint['model']) del checkpoint