updates
This commit is contained in:
		
							parent
							
								
									671747318d
								
							
						
					
					
						commit
						f20a03e28e
					
				
							
								
								
									
										3
									
								
								train.py
								
								
								
								
							
							
						
						
									
										3
									
								
								train.py
								
								
								
								
							|  | @ -381,11 +381,12 @@ if __name__ == '__main__': | ||||||
|     parser.add_argument('--arc', type=str, default='defaultpw', help='yolo architecture')  # defaultpw, uCE, uBCE |     parser.add_argument('--arc', type=str, default='defaultpw', help='yolo architecture')  # defaultpw, uCE, uBCE | ||||||
|     parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training') |     parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training') | ||||||
|     parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') |     parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') | ||||||
|  |     parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1') | ||||||
|     parser.add_argument('--var', type=float, help='debug variable') |     parser.add_argument('--var', type=float, help='debug variable') | ||||||
|     opt = parser.parse_args() |     opt = parser.parse_args() | ||||||
|     opt.weights = 'weights/last.pt' if opt.resume else opt.weights |     opt.weights = 'weights/last.pt' if opt.resume else opt.weights | ||||||
|     print(opt) |     print(opt) | ||||||
|     device = torch_utils.select_device(apex=mixed_precision) |     device = torch_utils.select_device(opt.device, apex=mixed_precision) | ||||||
| 
 | 
 | ||||||
|     tb_writer = None |     tb_writer = None | ||||||
|     if not opt.evolve:  # Train normally |     if not opt.evolve:  # Train normally | ||||||
|  |  | ||||||
|  | @ -1,3 +1,4 @@ | ||||||
|  | import os | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -13,7 +14,11 @@ def init_seeds(seed=0): | ||||||
|         torch.backends.cudnn.benchmark = False |         torch.backends.cudnn.benchmark = False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def select_device(force_cpu=False, apex=False): | def select_device(device=None, force_cpu=False, apex=False): | ||||||
|  |     # Set environment variable if device is specified | ||||||
|  |     if device: | ||||||
|  |         os.environ['CUDA_VISIBLE_DEVICES'] = device | ||||||
|  | 
 | ||||||
|     # apex if mixed precision training https://github.com/NVIDIA/apex |     # apex if mixed precision training https://github.com/NVIDIA/apex | ||||||
|     cuda = False if force_cpu else torch.cuda.is_available() |     cuda = False if force_cpu else torch.cuda.is_available() | ||||||
|     device = torch.device('cuda:0' if cuda else 'cpu') |     device = torch.device('cuda:0' if cuda else 'cpu') | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue