updates
This commit is contained in:
parent
dc704edf17
commit
10cca39934
|
@ -1,6 +1,6 @@
|
|||
classes=80
|
||||
train=/Users/glennjocher/Downloads/DATA/coco/trainvalno5k.txt
|
||||
valid=/Users/glennjocher/Downloads/DATA/coco/5k.txt
|
||||
train=../coco/trainvalno5k.txt
|
||||
valid=../coco/5k.txt
|
||||
names=data/coco.names
|
||||
backup=backup/
|
||||
eval=coco
|
||||
|
|
5
train.py
5
train.py
|
@ -44,10 +44,7 @@ def main(opt):
|
|||
# Configure run
|
||||
data_config = parse_data_config(opt.data_config_path)
|
||||
num_classes = int(data_config['classes'])
|
||||
if platform == 'darwin': # MacOS (local)
|
||||
train_path = data_config['train']
|
||||
else: # linux (cloud, i.e. gcp)
|
||||
train_path = '../coco/trainvalno5k.part'
|
||||
train_path = '../coco/trainvalno5k.txt'
|
||||
|
||||
# Initialize model
|
||||
model = Darknet(opt.cfg, opt.img_size)
|
||||
|
|
|
@ -66,12 +66,7 @@ class load_images_and_labels(): # for training
|
|||
with open(path, 'r') as file:
|
||||
self.img_files = file.readlines()
|
||||
|
||||
if platform == 'darwin': # MacOS (local)
|
||||
self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/data/coco/images')
|
||||
for path in self.img_files]
|
||||
else: # linux (gcp cloud)
|
||||
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
|
||||
|
||||
self.img_files = [path.replace('\n', '') for path in self.img_files]
|
||||
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in
|
||||
self.img_files]
|
||||
|
||||
|
@ -287,7 +282,7 @@ def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scal
|
|||
return imw
|
||||
|
||||
|
||||
def convert_tif2bmp(p='/Users/glennjocher/Downloads/DATA/xview/val_images_bmp'):
|
||||
def convert_tif2bmp(p='../xview/val_images_bmp'):
|
||||
import glob
|
||||
import cv2
|
||||
files = sorted(glob.glob('%s/*.tif' % p))
|
||||
|
|
|
@ -424,7 +424,7 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
|
|||
torch.save(a, filename.replace('.pt', '_lite.pt'))
|
||||
|
||||
|
||||
def coco_class_count(path='/Users/glennjocher/downloads/DATA/coco/labels/train2014/'):
|
||||
def coco_class_count(path='../coco/labels/train2014/'):
|
||||
import glob
|
||||
|
||||
nC = 80 # number classes
|
||||
|
|
Loading…
Reference in New Issue