updates
This commit is contained in:
parent
e2a8f5bdce
commit
1d760a7046
10
test.py
10
test.py
|
@ -5,7 +5,6 @@ from utils.datasets import *
|
|||
from utils.utils import *
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-epochs', type=int, default=200, help='number of epochs')
|
||||
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
|
||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
||||
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
|
||||
|
@ -32,11 +31,10 @@ num_classes = int(data_config['classes'])
|
|||
model = Darknet(opt.cfg, opt.img_size)
|
||||
|
||||
# Load weights
|
||||
weights_path = 'checkpoints/yolov3.weights'
|
||||
if weights_path.endswith('.weights'): # darknet format
|
||||
load_weights(model, weights_path)
|
||||
elif weights_path.endswith('.pt'): # pytorch format
|
||||
checkpoint = torch.load(weights_path, map_location='cpu')
|
||||
if opt.weights_path('.weights'): # darknet format
|
||||
load_weights(model, opt.weights_path)
|
||||
elif opt.weights_path.endswith('.pt'): # pytorch format
|
||||
checkpoint = torch.load(opt.weights_path, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
del checkpoint
|
||||
|
||||
|
|
Loading…
Reference in New Issue