diff --git a/train.py b/train.py index 09eb70c3..9699cd2f 100644 --- a/train.py +++ b/train.py @@ -178,7 +178,7 @@ def train(): init_method='tcp://127.0.0.1:9999', # distributed training init method world_size=1, # number of nodes for distributed training rank=0) # distributed training node rank - model = torch.nn.parallel.DistributedDataParallel(model) + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) model.yolo_layers = model.module.yolo_layers # move yolo layer indices to top level # Dataset