This commit is contained in:
Glenn Jocher 2019-08-07 16:45:13 +02:00
parent 056976b4fc
commit f43170817c
2 changed files with 4 additions and 3 deletions

View File

@ -187,7 +187,8 @@ def train(cfg,
augment=True, augment=True,
hyp=hyp, # augmentation hyperparameters hyp=hyp, # augmentation hyperparameters
rect=opt.rect, # rectangular training rect=opt.rect, # rectangular training
image_weights=opt.img_weights) image_weights=opt.img_weights,
cache_images=opt.cache_images)
# Dataloader # Dataloader
dataloader = torch.utils.data.DataLoader(dataset, dataloader = torch.utils.data.DataLoader(dataset,
@ -352,6 +353,7 @@ if __name__ == '__main__':
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--img-weights', action='store_true', help='select training images by weight') parser.add_argument('--img-weights', action='store_true', help='select training images by weight')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
opt = parser.parse_args() opt = parser.parse_args()
print(opt) print(opt)

View File

@ -156,7 +156,7 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False): def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False, cache_images=False):
path = str(Path(path)) # os-agnostic path = str(Path(path)) # os-agnostic
with open(path, 'r') as f: with open(path, 'r') as f:
self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
@ -254,7 +254,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
assert nf > 0, 'No labels found. Recommend correcting image and label paths.' assert nf > 0, 'No labels found. Recommend correcting image and label paths.'
# Cache images into memory for faster training (~5GB) # Cache images into memory for faster training (~5GB)
cache_images = False
if cache_images and augment: # if training if cache_images and augment: # if training
for i in tqdm(range(min(len(self.img_files), 10000)), desc='Reading images'): # max 10k images for i in tqdm(range(min(len(self.img_files), 10000)), desc='Reading images'): # max 10k images
img_path = self.img_files[i] img_path = self.img_files[i]