updates
This commit is contained in:
		
							parent
							
								
									4286bba40f
								
							
						
					
					
						commit
						d5b5f74167
					
				|  | @ -17,12 +17,12 @@ def init_seeds(seed=0): | ||||||
| 
 | 
 | ||||||
| def select_device(device=None, apex=False): | def select_device(device=None, apex=False): | ||||||
|     if device == 'cpu': |     if device == 'cpu': | ||||||
|         force_cpu = True |         pass | ||||||
|     elif device:  # Set environment variable if device is specified |     elif device:  # Set environment variable if device is specified | ||||||
|         os.environ['CUDA_VISIBLE_DEVICES'] = device |         os.environ['CUDA_VISIBLE_DEVICES'] = device | ||||||
| 
 | 
 | ||||||
|     # apex if mixed precision training https://github.com/NVIDIA/apex |     # apex if mixed precision training https://github.com/NVIDIA/apex | ||||||
|     cuda = False if force_cpu else torch.cuda.is_available() |     cuda = False if device == 'cpu' else torch.cuda.is_available() | ||||||
|     device = torch.device('cuda:0' if cuda else 'cpu') |     device = torch.device('cuda:0' if cuda else 'cpu') | ||||||
| 
 | 
 | ||||||
|     if not cuda: |     if not cuda: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue