rename /checkpoints to /weights
This commit is contained in:
parent
0ae90d0fb7
commit
332fe002b3
10
models.py
10
models.py
|
@ -266,8 +266,12 @@ class Darknet(nn.Module):
|
||||||
return sum(output) if is_training else torch.cat(output, 1)
|
return sum(output) if is_training else torch.cat(output, 1)
|
||||||
|
|
||||||
|
|
||||||
def load_weights(self, weights_path):
|
def load_weights(self, weights_path, cutoff=-1):
|
||||||
"""Parses and loads the weights stored in 'weights_path'"""
|
# Parses and loads the weights stored in 'weights_path'
|
||||||
|
# @:param cutoff - save layers between 0 and cutoff (cutoff = -1 -> all are saved)
|
||||||
|
|
||||||
|
if weights_path.endswith('darknet53.conv.74'):
|
||||||
|
cutoff = 75
|
||||||
|
|
||||||
# Open the weights file
|
# Open the weights file
|
||||||
fp = open(weights_path, 'rb')
|
fp = open(weights_path, 'rb')
|
||||||
|
@ -281,7 +285,7 @@ def load_weights(self, weights_path):
|
||||||
fp.close()
|
fp.close()
|
||||||
|
|
||||||
ptr = 0
|
ptr = 0
|
||||||
for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
|
for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])):
|
||||||
if module_def['type'] == 'convolutional':
|
if module_def['type'] == 'convolutional':
|
||||||
conv_layer = module[0]
|
conv_layer = module[0]
|
||||||
if module_def['batch_normalize']:
|
if module_def['batch_normalize']:
|
||||||
|
|
2
train.py
2
train.py
|
@ -73,6 +73,8 @@ def main(opt):
|
||||||
|
|
||||||
del checkpoint # current, saved
|
del checkpoint # current, saved
|
||||||
else:
|
else:
|
||||||
|
load_weights(model, 'weights/darknet53.conv.74') # load darknet53 weights (optional)
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
print('Using ', torch.cuda.device_count(), ' GPUs')
|
print('Using ', torch.cuda.device_count(), ' GPUs')
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
|
@ -67,7 +67,7 @@ class load_images_and_labels(): # for training
|
||||||
self.img_files = file.readlines()
|
self.img_files = file.readlines()
|
||||||
|
|
||||||
if platform == 'darwin': # MacOS (local)
|
if platform == 'darwin': # MacOS (local)
|
||||||
self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images')
|
self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/data/coco/images')
|
||||||
for path in self.img_files]
|
for path in self.img_files]
|
||||||
else: # linux (gcp cloud)
|
else: # linux (gcp cloud)
|
||||||
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
|
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
wget https://pjreddie.com/media/files/darknet53.conv.74
|
||||||
wget https://pjreddie.com/media/files/yolov3.weights
|
wget https://pjreddie.com/media/files/yolov3.weights
|
||||||
wget https://storage.googleapis.com/ultralytics/yolov3.pt
|
wget https://storage.googleapis.com/ultralytics/yolov3.pt
|
||||||
|
|
Loading…
Reference in New Issue