diff --git a/test.py b/test.py index 7bcef824..a0a71b14 100644 --- a/test.py +++ b/test.py @@ -31,7 +31,7 @@ num_classes = int(data_config['classes']) model = Darknet(opt.cfg, opt.img_size) # Load weights -if opt.weights_path('.weights'): # darknet format +if opt.weights_path.endswith('.weights'): # darknet format load_weights(model, opt.weights_path) elif opt.weights_path.endswith('.pt'): # pytorch format checkpoint = torch.load(opt.weights_path, map_location='cpu')