rename /checkpoints to /weights

This commit is contained in:
Glenn Jocher 2018-10-30 14:58:26 +01:00
parent 0ae90d0fb7
commit 332fe002b3
4 changed files with 11 additions and 4 deletions

View File

@ -266,8 +266,12 @@ class Darknet(nn.Module):
return sum(output) if is_training else torch.cat(output, 1) return sum(output) if is_training else torch.cat(output, 1)
def load_weights(self, weights_path): def load_weights(self, weights_path, cutoff=-1):
"""Parses and loads the weights stored in 'weights_path'""" # Parses and loads the weights stored in 'weights_path'
# @:param cutoff - save layers between 0 and cutoff (cutoff = -1 -> all are saved)
if weights_path.endswith('darknet53.conv.74'):
cutoff = 75
# Open the weights file # Open the weights file
fp = open(weights_path, 'rb') fp = open(weights_path, 'rb')
@ -281,7 +285,7 @@ def load_weights(self, weights_path):
fp.close() fp.close()
ptr = 0 ptr = 0
for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)): for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])):
if module_def['type'] == 'convolutional': if module_def['type'] == 'convolutional':
conv_layer = module[0] conv_layer = module[0]
if module_def['batch_normalize']: if module_def['batch_normalize']:

View File

@ -73,6 +73,8 @@ def main(opt):
del checkpoint # current, saved del checkpoint # current, saved
else: else:
load_weights(model, 'weights/darknet53.conv.74') # load darknet53 weights (optional)
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
print('Using ', torch.cuda.device_count(), ' GPUs') print('Using ', torch.cuda.device_count(), ' GPUs')
model = nn.DataParallel(model) model = nn.DataParallel(model)

View File

@ -67,7 +67,7 @@ class load_images_and_labels(): # for training
self.img_files = file.readlines() self.img_files = file.readlines()
if platform == 'darwin': # MacOS (local) if platform == 'darwin': # MacOS (local)
self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images') self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/data/coco/images')
for path in self.img_files] for path in self.img_files]
else: # linux (gcp cloud) else: # linux (gcp cloud)
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files] self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]

View File

@ -1,4 +1,5 @@
#!/bin/bash #!/bin/bash
wget https://pjreddie.com/media/files/darknet53.conv.74
wget https://pjreddie.com/media/files/yolov3.weights wget https://pjreddie.com/media/files/yolov3.weights
wget https://storage.googleapis.com/ultralytics/yolov3.pt wget https://storage.googleapis.com/ultralytics/yolov3.pt