initialize from darknet53

This commit is contained in:
Glenn Jocher 2018-10-30 15:18:52 +01:00
parent ed0390d0b5
commit 26c52f9485
1 changed files with 6 additions and 1 deletions

View File

@ -72,8 +72,13 @@ def main(opt):
best_loss = checkpoint['best_loss']
del checkpoint # current, saved
else:
load_weights(model, 'weights/darknet53.conv.74') # load darknet53 weights (optional)
# Initialize model with darknet53 weights (optional)
if not os.path.isfile('weights/darknet53.conv.74'):
os.system('wget https://pjreddie.com/media/files/darknet53.conv.74 -P /weights')
load_weights(model, 'weights/darknet53.conv.74')
if torch.cuda.device_count() > 1:
print('Using ', torch.cuda.device_count(), ' GPUs')
model = nn.DataParallel(model)