updates
This commit is contained in:
		
							parent
							
								
									eb4acecbb5
								
							
						
					
					
						commit
						334c7c94cf
					
				
							
								
								
									
										7
									
								
								train.py
								
								
								
								
							
							
						
						
									
										7
									
								
								train.py
								
								
								
								
							|  | @ -121,9 +121,7 @@ def train( | |||
|     if torch.cuda.device_count() > 1: | ||||
|         dist.init_process_group(backend=opt.backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank) | ||||
|         model = torch.nn.parallel.DistributedDataParallel(model) | ||||
|         sampler = torch.utils.data.distributed.DistributedSampler(dataset) | ||||
|     else: | ||||
|         sampler = None | ||||
|         # sampler = torch.utils.data.distributed.DistributedSampler(dataset) | ||||
| 
 | ||||
|     # Dataloader | ||||
|     dataloader = DataLoader(dataset, | ||||
|  | @ -131,8 +129,7 @@ def train( | |||
|                             num_workers=opt.num_workers, | ||||
|                             shuffle=True, | ||||
|                             pin_memory=True, | ||||
|                             collate_fn=dataset.collate_fn, | ||||
|                             sampler=sampler) | ||||
|                             collate_fn=dataset.collate_fn) | ||||
| 
 | ||||
|     # Mixed precision training https://github.com/NVIDIA/apex | ||||
|     # install help: https://github.com/NVIDIA/apex/issues/259 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue