diff --git a/train.py b/train.py index 3ca70760..4d2b5c15 100644 --- a/train.py +++ b/train.py @@ -120,6 +120,7 @@ def train( sampler=sampler) # Mixed precision training https://github.com/NVIDIA/apex + # install help: https://github.com/NVIDIA/apex/issues/259 mixed_precision = False if mixed_precision: from apex import amp