updates
This commit is contained in:
parent
27ca52c9ee
commit
7787090165
5
train.py
5
train.py
|
@ -113,7 +113,7 @@ def train(
|
|||
# Dataloader
|
||||
dataloader = DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=opt.num_workers,
|
||||
num_workers=0,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
collate_fn=dataset.collate_fn,
|
||||
|
@ -170,6 +170,9 @@ def train(
|
|||
|
||||
# Compute loss
|
||||
loss, loss_items = compute_loss(pred, targets, model)
|
||||
if torch.isnan(loss):
|
||||
print('WARNING: nan loss detected, ending training')
|
||||
return results
|
||||
|
||||
# Compute gradient
|
||||
if mixed_precision:
|
||||
|
|
Loading…
Reference in New Issue