updates
This commit is contained in:
parent
1a0385c77d
commit
f501a0fc9d
7
train.py
7
train.py
|
@ -168,11 +168,12 @@ def train(
|
||||||
collate_fn=dataset.collate_fn)
|
collate_fn=dataset.collate_fn)
|
||||||
|
|
||||||
# Mixed precision training https://github.com/NVIDIA/apex
|
# Mixed precision training https://github.com/NVIDIA/apex
|
||||||
# install help: https://github.com/NVIDIA/apex/issues/259
|
try:
|
||||||
mixed_precision = False
|
|
||||||
if mixed_precision:
|
|
||||||
from apex import amp
|
from apex import amp
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||||
|
mixed_precision = True
|
||||||
|
except: # not installed: install help: https://github.com/NVIDIA/apex/issues/259
|
||||||
|
mixed_precision = False
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
|
|
Loading…
Reference in New Issue