This commit is contained in:
Glenn Jocher 2019-04-17 18:33:16 +02:00
parent 27ca52c9ee
commit 7787090165
1 changed files with 4 additions and 1 deletions

View File

@ -113,7 +113,7 @@ def train(
# Dataloader # Dataloader
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=opt.num_workers, num_workers=0,
shuffle=True, shuffle=True,
pin_memory=True, pin_memory=True,
collate_fn=dataset.collate_fn, collate_fn=dataset.collate_fn,
@ -170,6 +170,9 @@ def train(
# Compute loss # Compute loss
loss, loss_items = compute_loss(pred, targets, model) loss, loss_items = compute_loss(pred, targets, model)
if torch.isnan(loss):
print('WARNING: nan loss detected, ending training')
return results
# Compute gradient # Compute gradient
if mixed_precision: if mixed_precision: