updates
This commit is contained in:
parent
1a0385c77d
commit
f501a0fc9d
7
train.py
7
train.py
|
@ -168,11 +168,12 @@ def train(
|
|||
collate_fn=dataset.collate_fn)
|
||||
|
||||
# Mixed precision training https://github.com/NVIDIA/apex
|
||||
# install help: https://github.com/NVIDIA/apex/issues/259
|
||||
mixed_precision = False
|
||||
if mixed_precision:
|
||||
try:
|
||||
from apex import amp
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||
mixed_precision = True
|
||||
except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259
|
||||
mixed_precision = False
|
||||
|
||||
# Start training
|
||||
model.hyp = hyp # attach hyperparameters to model
|
||||
|
|
Loading…
Reference in New Issue