This commit is contained in:
Glenn Jocher 2019-06-22 15:50:04 +02:00
parent 1a0385c77d
commit f501a0fc9d
1 changed files with 4 additions and 3 deletions

View File

@ -168,11 +168,12 @@ def train(
collate_fn=dataset.collate_fn)
# Mixed precision training https://github.com/NVIDIA/apex
# install help: https://github.com/NVIDIA/apex/issues/259
mixed_precision = False
if mixed_precision:
try:
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
mixed_precision = True
except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259
mixed_precision = False
# Start training
model.hyp = hyp # attach hyperparameters to model