Update train.py

This commit is contained in:
Glenn Jocher 2019-03-25 15:03:13 +01:00 committed by GitHub
parent cd51e1137b
commit c7192f64c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -63,7 +63,7 @@ def train(
#initialize for distributed training #initialize for distributed training
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank) dist.init_process_group(backend=opt.backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
model = torch.nn.parallel.DistributedDataParallel(model) model = torch.nn.parallel.DistributedDataParallel(model)
# Transfer learning (train only YOLO layers) # Transfer learning (train only YOLO layers)