diff --git a/models.py b/models.py index 44904203..6183dbfc 100755 --- a/models.py +++ b/models.py @@ -266,8 +266,12 @@ class Darknet(nn.Module): return sum(output) if is_training else torch.cat(output, 1) -def load_weights(self, weights_path): - """Parses and loads the weights stored in 'weights_path'""" +def load_weights(self, weights_path, cutoff=-1): + # Parses and loads the weights stored in 'weights_path' + # @:param cutoff - save layers between 0 and cutoff (cutoff = -1 -> all are saved) + + if weights_path.endswith('darknet53.conv.74'): + cutoff = 75 # Open the weights file fp = open(weights_path, 'rb') @@ -281,7 +285,7 @@ def load_weights(self, weights_path): fp.close() ptr = 0 - for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)): + for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])): if module_def['type'] == 'convolutional': conv_layer = module[0] if module_def['batch_normalize']: diff --git a/train.py b/train.py index a1d8f7d1..fe9f56e1 100644 --- a/train.py +++ b/train.py @@ -73,6 +73,8 @@ def main(opt): del checkpoint # current, saved else: + load_weights(model, 'weights/darknet53.conv.74') # load darknet53 weights (optional) + if torch.cuda.device_count() > 1: print('Using ', torch.cuda.device_count(), ' GPUs') model = nn.DataParallel(model) diff --git a/utils/datasets.py b/utils/datasets.py index 0ecdfa5e..37982fbf 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -67,7 +67,7 @@ class load_images_and_labels(): # for training self.img_files = file.readlines() if platform == 'darwin': # MacOS (local) - self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images') + self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/data/coco/images') for path in self.img_files] else: # linux (gcp cloud) self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files] diff --git a/weights/download_yolov3_weights.sh b/weights/download_yolov3_weights.sh index c46580c5..aadcb453 100644 --- a/weights/download_yolov3_weights.sh +++ b/weights/download_yolov3_weights.sh @@ -1,4 +1,5 @@ #!/bin/bash +wget https://pjreddie.com/media/files/darknet53.conv.74 wget https://pjreddie.com/media/files/yolov3.weights wget https://storage.googleapis.com/ultralytics/yolov3.pt