From 5daea5882fc6ecaa78ef6cb7cea078493a78b5ce Mon Sep 17 00:00:00 2001
From: perry0418 <34980036+perry0418@users.noreply.github.com>
Date: Mon, 25 Mar 2019 16:13:21 +0800
Subject: [PATCH] Update train.py

Fix multi-GPU training: pass rank=opt.rank (the parsed options) to
dist.init_process_group instead of args.rank, and register the
--dist-backend option that the call relies on.
---
 train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 6e461431..8bb87ce4 100644
--- a/train.py
+++ b/train.py
@@ -62,7 +62,7 @@ def train(
 
     #initialize for distributed training
     if torch.cuda.device_count() > 1:
-        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=args.rank)
+        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
         model = torch.nn.parallel.DistributedDataParallel(model)
 
     # Dataloader
@@ -215,6 +215,7 @@ if __name__ == '__main__':
     parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,help='url used to set up distributed training')
     parser.add_argument('--rank', default=-1, type=int,help='node rank for distributed training')
    parser.add_argument('--world-size', default=-1, type=int,help='number of nodes for distributed training')
+    parser.add_argument('--dist-backend', default='nccl', type=str,help='distributed backend')
 
     opt = parser.parse_args()
     print(opt, end='\n\n')
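
For reference, a minimal standalone sketch of how these options feed torch.distributed; the Linear placeholder model and the default values below are illustrative, not the repository's code:

import argparse

import torch
import torch.distributed as dist
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str, help='url used to set up distributed training')
parser.add_argument('--rank', default=0, type=int, help='node rank for distributed training')
parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
opt = parser.parse_args()

model = nn.Linear(10, 10).cuda()  # placeholder for the real model

if torch.cuda.device_count() > 1:
    # rank, world size, rendezvous url and backend all come from opt (the
    # parsed options), matching the rank=opt.rank fix in the patch above.
    dist.init_process_group(backend=opt.dist_backend,
                            init_method=opt.dist_url,
                            world_size=opt.world_size,
                            rank=opt.rank)
    # Wrap the model so gradients are synchronized across processes.
    model = torch.nn.parallel.DistributedDataParallel(model)

With the patch applied, each process would presumably be launched along the lines of
python train.py --world-size 2 --rank 0 --dist-backend nccl --dist-url tcp://<host>:<port>
with --rank 1 on the second node.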