diff --git a/train.py b/train.py index 666da626..b1f0489e 100644 --- a/train.py +++ b/train.py @@ -168,11 +168,12 @@ def train( collate_fn=dataset.collate_fn) # Mixed precision training https://github.com/NVIDIA/apex - # install help: https://github.com/NVIDIA/apex/issues/259 - mixed_precision = False - if mixed_precision: + try: from apex import amp model, optimizer = amp.initialize(model, optimizer, opt_level='O1') + mixed_precision = True + except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259 + mixed_precision = False # Start training model.hyp = hyp # attach hyperparameters to model