Update train.py

Fix multi-GPU distributed training: pass opt.rank (the parsed command-line argument) to dist.init_process_group instead of the undefined args.rank, and add a --dist-backend argument.
perry0418 2019-03-25 16:13:21 +08:00 committed by GitHub
parent 4884508110
commit 5daea5882f
1 changed file with 2 additions and 1 deletion


@@ -62,7 +62,7 @@ def train(
     #initialize for distributed training
     if torch.cuda.device_count() > 1:
-        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=args.rank)
+        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
         model = torch.nn.parallel.DistributedDataParallel(model)

     # Dataloader
@@ -215,6 +215,7 @@ if __name__ == '__main__':
     parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,help='url used to set up distributed training')
     parser.add_argument('--rank', default=-1, type=int,help='node rank for distributed training')
     parser.add_argument('--world-size', default=-1, type=int,help='number of nodes for distributed training')
+    parser.add_argument('--dist-backend', default='nccl', type=str,help='distributed backend')
     opt = parser.parse_args()
     print(opt, end='\n\n')
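
For reference, below is a minimal, self-contained sketch of the initialization pattern this patch relies on, assuming the same flag names as the arguments above; the placeholder model and the way the rank is supplied at launch are illustrative assumptions and are not part of the commit.

# Illustrative sketch only; mirrors the patched block in train.py, not part of this commit.
import argparse

import torch
import torch.distributed as dist
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training')
parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
opt = parser.parse_args()

model = nn.Linear(8, 8)  # placeholder for the model built in train.py

# Initialize the process group only when several GPUs are present, then wrap the model.
# Note rank=opt.rank: the parsed argument, not the undefined args.rank this commit removes.
if torch.cuda.device_count() > 1:
    dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
                            world_size=opt.world_size, rank=opt.rank)
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)

With these flags, a two-node run could for example be started with python train.py --world-size 2 --rank 0 on the first node and --rank 1 on the second; the exact launch procedure depends on the cluster setup.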