diff --git a/train.py b/train.py
index 6e461431..8bb87ce4 100644
--- a/train.py
+++ b/train.py
@@ -62,7 +62,7 @@ def train(
 
     #initialize for distributed training
     if torch.cuda.device_count() > 1:
-        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=args.rank)
+        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
         model = torch.nn.parallel.DistributedDataParallel(model)
 
     # Dataloader
@@ -215,6 +215,7 @@ if __name__ == '__main__':
     parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,help='url used to set up distributed training')
     parser.add_argument('--rank', default=-1, type=int,help='node rank for distributed training')
    parser.add_argument('--world-size', default=-1, type=int,help='number of nodes for distributed training')
+    parser.add_argument('--dist-backend', default='nccl', type=str,help='distributed backend')
     opt = parser.parse_args()
     print(opt, end='\n\n')
 
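Both hunks close the same gap: inside `train()` every distributed option is read from the parsed namespace `opt` (`opt.dist_backend`, `opt.dist_url`, `opt.world_size`), but the rank alone was fetched from `args`, which is not the name of the parsed namespace, and `--dist-backend` was consumed without ever being registered with the parser. A standalone sketch of how the corrected pieces fit together (not part of the patch; the `Linear` model and the `.cuda()` call are placeholders, and the defaults are the script's own):

```python
# Minimal sketch of the corrected pattern; placeholder model, real PyTorch APIs.
import argparse

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
# The flag train() already expected via opt.dist_backend, added by this patch:
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
opt = parser.parse_args()

model = torch.nn.Linear(4, 2)  # placeholder model, not from the patch

if torch.cuda.device_count() > 1:
    # Every field is read from the one parsed namespace `opt`; before the
    # patch, rank alone was looked up on `args`, which doesn't exist here.
    dist.init_process_group(backend=opt.dist_backend,
                            init_method=opt.dist_url,
                            world_size=opt.world_size,
                            rank=opt.rank)
    # NCCL expects the module on GPU before wrapping (sketch-level detail).
    model = torch.nn.parallel.DistributedDataParallel(model.cuda())
```

Each node would launch the script with the shared `--dist-url` and its own `--rank`; `init_process_group` blocks until all `--world-size` ranks have joined the group.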