updates
This commit is contained in:
		
							parent
							
								
									a024286ec1
								
							
						
					
					
						commit
						d047062074
					
				
							
								
								
									
										2
									
								
								train.py
								
								
								
								
							
							
						
						
									
										2
									
								
								train.py
								
								
								
								
							|  | @ -51,7 +51,7 @@ def train( | ||||||
|     start_epoch = 0 |     start_epoch = 0 | ||||||
|     best_loss = float('inf') |     best_loss = float('inf') | ||||||
|     if resume:  # Load previously saved PyTorch model |     if resume:  # Load previously saved PyTorch model | ||||||
|         checkpoint = torch.load(latest, map_location=device)  # load checkpoin |         checkpoint = torch.load(latest, map_location=device)  # load checkpoint | ||||||
|         model.load_state_dict(checkpoint['model']) |         model.load_state_dict(checkpoint['model']) | ||||||
|         start_epoch = checkpoint['epoch'] + 1 |         start_epoch = checkpoint['epoch'] + 1 | ||||||
|         if checkpoint['optimizer'] is not None: |         if checkpoint['optimizer'] is not None: | ||||||
|  |  | ||||||
|  | @ -108,7 +108,7 @@ class LoadImagesAndLabels:  # for training | ||||||
| 
 | 
 | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         self.count = -1 |         self.count = -1 | ||||||
|         self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF) |         #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF) | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
|     def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
|  | @ -133,8 +133,8 @@ class LoadImagesAndLabels:  # for training | ||||||
|     def load_images(self, ia, ib): |     def load_images(self, ia, ib): | ||||||
|         img_all, labels_all, img_paths, img_shapes = [], [], [], [] |         img_all, labels_all, img_paths, img_shapes = [], [], [], [] | ||||||
|         for index, files_index in enumerate(range(ia, ib)): |         for index, files_index in enumerate(range(ia, ib)): | ||||||
|             img_path = self.img_files[self.shuffled_vector[files_index]] |             img_path = self.img_files[files_index] | ||||||
|             label_path = self.label_files[self.shuffled_vector[files_index]] |             label_path = self.label_files[files_index] | ||||||
| 
 | 
 | ||||||
|             img = cv2.imread(img_path)  # BGR |             img = cv2.imread(img_path)  # BGR | ||||||
|             assert img is not None, 'File Not Found ' + img_path |             assert img is not None, 'File Not Found ' + img_path | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue