This commit is contained in:
Glenn Jocher 2019-04-22 23:27:31 +02:00
parent eb4acecbb5
commit 334c7c94cf
1 changed file with 2 additions and 5 deletions

View File

@@ -121,9 +121,7 @@ def train(
if torch.cuda.device_count() > 1:
dist.init_process_group(backend=opt.backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank)
model = torch.nn.parallel.DistributedDataParallel(model)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
sampler = None
# sampler = torch.utils.data.distributed.DistributedSampler(dataset)
# Dataloader
dataloader = DataLoader(dataset,
@@ -131,8 +129,7 @@ def train(
num_workers=opt.num_workers,
shuffle=True,
pin_memory=True,
collate_fn=dataset.collate_fn,
sampler=sampler)
collate_fn=dataset.collate_fn)
# Mixed precision training https://github.com/NVIDIA/apex
# install help: https://github.com/NVIDIA/apex/issues/259