This commit is contained in:
Glenn Jocher 2019-04-13 16:02:45 +02:00
parent 95696d24c0
commit f299d83f40
1 changed files with 11 additions and 1 deletions

View File

@ -97,6 +97,12 @@ def train(
collate_fn=dataset.collate_fn,
sampler=sampler)
# Mixed precision training https://github.com/NVIDIA/apex
mixed_precision = False
if mixed_precision:
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level='01')
# Start training
t = time.time()
model_info(model)
@ -145,6 +151,10 @@ def train(
loss, loss_dict = compute_loss(pred, target_list)
# Compute gradient
if mixed_precision:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# Accumulate gradient for x batches before optimizing