updates
This commit is contained in:
parent
7823473d2f
commit
84371f6811
2
train.py
2
train.py
|
@@ -158,7 +158,7 @@ def train():
|
||||||
# plt.savefig('LR.png', dpi=300)
|
# plt.savefig('LR.png', dpi=300)
|
||||||
|
|
||||||
# Initialize distributed training
|
# Initialize distributed training
|
||||||
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
|
||||||
dist.init_process_group(backend='nccl', # 'distributed backend'
|
dist.init_process_group(backend='nccl', # 'distributed backend'
|
||||||
init_method='tcp://127.0.0.1:9999', # distributed training init method
|
init_method='tcp://127.0.0.1:9999', # distributed training init method
|
||||||
world_size=1, # number of nodes for distributed training
|
world_size=1, # number of nodes for distributed training
|
||||||
|
|
Loading…
Reference in New Issue