updates

parent 95696d24c0
commit f299d83f40

train.py | 10
@@ -97,6 +97,12 @@ def train(
                             collate_fn=dataset.collate_fn,
                             sampler=sampler)
 
+    # Mixed precision training https://github.com/NVIDIA/apex
+    mixed_precision = False
+    if mixed_precision:
+        from apex import amp
+        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
+
     # Start training
     t = time.time()
     model_info(model)
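Note: the hunk above gates apex behind a `mixed_precision` flag and wraps the model and optimizer with `amp.initialize` (the `opt_level` is the letter O plus a digit, 'O0'..'O3'). A minimal sketch of the same setup, assuming apex is installed and a CUDA device is available; the linear model and SGD optimizer below are hypothetical stand-ins for the ones built earlier in train():

import torch
import torch.nn as nn

from apex import amp  # https://github.com/NVIDIA/apex

model = nn.Linear(10, 2).cuda()  # hypothetical stand-in for the Darknet model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 'O1' patches most ops to run in fp16 while keeping fp32 master weights
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')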
@@ -145,6 +151,10 @@ def train(
             loss, loss_dict = compute_loss(pred, target_list)
 
             # Compute gradient
-            loss.backward()
+            if mixed_precision:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()
 
             # Accumulate gradient for x batches before optimizing
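Note: `amp.scale_loss` scales the loss before backward so fp16 gradients do not underflow, then unscales them on the optimizer. A minimal sketch of this backward pass together with the gradient accumulation mentioned in the trailing context line; `dataloader`, `compute_loss`, and the `accumulate` interval are assumptions, not the repo's exact names:

# Continues from the setup sketch above (model, optimizer, amp in scope).
# dataloader and compute_loss are hypothetical stand-ins for the repo's own.
mixed_precision = True
accumulate = 4  # hypothetical: step the optimizer every 4 batches

optimizer.zero_grad()
for i, (imgs, targets) in enumerate(dataloader):
    pred = model(imgs.cuda())
    loss = compute_loss(pred, targets)

    if mixed_precision:
        # amp scales the loss so fp16 gradients survive the backward pass
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    # Accumulate gradients for `accumulate` batches before optimizing
    if (i + 1) % accumulate == 0:
        optimizer.step()
        optimizer.zero_grad()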