diff --git a/train.py b/train.py index 3e983127..448f7774 100644 --- a/train.py +++ b/train.py @@ -63,7 +63,7 @@ def train( #initialize for distributed training if torch.cuda.device_count() > 1: - dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank) + dist.init_process_group(backend=opt.backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank) model = torch.nn.parallel.DistributedDataParallel(model) # Transfer learning (train only YOLO layers)