updates
This commit is contained in:
parent
27ca52c9ee
commit
7787090165
5
train.py
5
train.py
|
@ -113,7 +113,7 @@ def train(
|
||||||
# Dataloader
|
# Dataloader
|
||||||
dataloader = DataLoader(dataset,
|
dataloader = DataLoader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=opt.num_workers,
|
num_workers=0,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=dataset.collate_fn,
|
collate_fn=dataset.collate_fn,
|
||||||
|
@ -170,6 +170,9 @@ def train(
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
loss, loss_items = compute_loss(pred, targets, model)
|
loss, loss_items = compute_loss(pred, targets, model)
|
||||||
|
if torch.isnan(loss):
|
||||||
|
print('WARNING: nan loss detected, ending training')
|
||||||
|
return results
|
||||||
|
|
||||||
# Compute gradient
|
# Compute gradient
|
||||||
if mixed_precision:
|
if mixed_precision:
|
||||||
|
|
Loading…
Reference in New Issue