updates
This commit is contained in:
		
							parent
							
								
									056976b4fc
								
							
						
					
					
						commit
						f43170817c
					
				
							
								
								
									
										4
									
								
								train.py
								
								
								
								
							
							
						
						
									
										4
									
								
								train.py
								
								
								
								
							|  | @ -187,7 +187,8 @@ def train(cfg, | ||||||
|                                   augment=True, |                                   augment=True, | ||||||
|                                   hyp=hyp,  # augmentation hyperparameters |                                   hyp=hyp,  # augmentation hyperparameters | ||||||
|                                   rect=opt.rect,  # rectangular training |                                   rect=opt.rect,  # rectangular training | ||||||
|                                   image_weights=opt.img_weights) |                                   image_weights=opt.img_weights, | ||||||
|  |                                   cache_images=opt.cache_images) | ||||||
| 
 | 
 | ||||||
|     # Dataloader |     # Dataloader | ||||||
|     dataloader = torch.utils.data.DataLoader(dataset, |     dataloader = torch.utils.data.DataLoader(dataset, | ||||||
|  | @ -352,6 +353,7 @@ if __name__ == '__main__': | ||||||
|     parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') |     parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') | ||||||
|     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') |     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') | ||||||
|     parser.add_argument('--img-weights', action='store_true', help='select training images by weight') |     parser.add_argument('--img-weights', action='store_true', help='select training images by weight') | ||||||
|  |     parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') | ||||||
|     opt = parser.parse_args() |     opt = parser.parse_args() | ||||||
|     print(opt) |     print(opt) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -156,7 +156,7 @@ class LoadWebcam:  # for inference | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class LoadImagesAndLabels(Dataset):  # for training/testing | class LoadImagesAndLabels(Dataset):  # for training/testing | ||||||
|     def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False): |     def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False, cache_images=False): | ||||||
|         path = str(Path(path))  # os-agnostic |         path = str(Path(path))  # os-agnostic | ||||||
|         with open(path, 'r') as f: |         with open(path, 'r') as f: | ||||||
|             self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines()  # os-agnostic |             self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines()  # os-agnostic | ||||||
|  | @ -254,7 +254,6 @@ class LoadImagesAndLabels(Dataset):  # for training/testing | ||||||
|             assert nf > 0, 'No labels found. Recommend correcting image and label paths.' |             assert nf > 0, 'No labels found. Recommend correcting image and label paths.' | ||||||
| 
 | 
 | ||||||
|         # Cache images into memory for faster training (~5GB) |         # Cache images into memory for faster training (~5GB) | ||||||
|         cache_images = False |  | ||||||
|         if cache_images and augment:  # if training |         if cache_images and augment:  # if training | ||||||
|             for i in tqdm(range(min(len(self.img_files), 10000)), desc='Reading images'):  # max 10k images |             for i in tqdm(range(min(len(self.img_files), 10000)), desc='Reading images'):  # max 10k images | ||||||
|                 img_path = self.img_files[i] |                 img_path = self.img_files[i] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue