Update train.py
This commit is contained in:
parent
cd51e1137b
commit
c7192f64c9
2
train.py
2
train.py
|
@ -63,7 +63,7 @@ def train(
|
|||
|
||||
#initialize for distributed training
|
||||
if torch.cuda.device_count() > 1:
|
||||
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
|
||||
dist.init_process_group(backend=opt.backend, init_method=opt.dist_url,world_size=opt.world_size, rank=opt.rank)
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
|
||||
# Transfer learning (train only YOLO layers)
|
||||
|
|
Loading…
Reference in New Issue