diff --git a/train.py b/train.py
index 1a32e3c0..88cf3c42 100644
--- a/train.py
+++ b/train.py
@@ -158,7 +158,7 @@ def train():
         # plt.savefig('LR.png', dpi=300)

     # Initialize distributed training
-    if device.type != 'cpu' and torch.cuda.device_count() > 1:
+    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
        dist.init_process_group(backend='nccl',  # 'distributed backend'
                                init_method='tcp://127.0.0.1:9999',  # distributed training init method
                                world_size=1,  # number of nodes for distributed training