updates
commit f299d83f40
parent 95696d24c0
train.py | 12 +++++++++++-
@@ -97,6 +97,12 @@ def train(
                             collate_fn=dataset.collate_fn,
                             sampler=sampler)

+    # Mixed precision training https://github.com/NVIDIA/apex
+    mixed_precision = False
+    if mixed_precision:
+        from apex import amp
+        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
+
     # Start training
     t = time.time()
     model_info(model)
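The first hunk wraps the model and optimizer with apex amp before the training loop starts. Note that opt_level takes the capital letter O followed by a digit ('O0' through 'O3'), not a zero; the committed '01' would fail at runtime, so it is corrected to 'O1' above. A minimal, self-contained sketch of the same initialization pattern, assuming apex is installed from https://github.com/NVIDIA/apex and a CUDA device is available; the model and optimizer here are illustrative stand-ins, not this repository's Darknet setup:

    import torch
    from apex import amp

    mixed_precision = True

    model = torch.nn.Linear(10, 2).cuda()  # stand-in for the detection model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    if mixed_precision:
        # 'O1' patches whitelisted ops to run in FP16 while keeping
        # master weights in FP32; 'O0' is pure FP32, 'O3' pure FP16.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')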
@@ -145,7 +151,11 @@ def train(
             loss, loss_dict = compute_loss(pred, target_list)

             # Compute gradient
-            loss.backward()
+            if mixed_precision:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()

             # Accumulate gradient for x batches before optimizing
             if (i + 1) % accumulate == 0 or (i + 1) == nB:
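The second hunk routes the backward pass through amp.scale_loss when mixed precision is enabled: the loss is scaled up before backward so small FP16 gradients do not underflow, and amp unscales them again before the optimizer step. A sketch of one accumulation cycle under the same setup as the previous sketch; the batch data, loss, and the accumulate and nB values are illustrative stand-ins, not this repository's dataloader:

    import torch.nn.functional as F

    accumulate = 4  # illustrative: step the optimizer every 4 batches
    nB = 8          # illustrative number of batches per epoch
    for i in range(nB):
        x = torch.randn(16, 10).cuda()           # stand-in batch
        y = torch.randint(0, 2, (16,)).cuda()    # stand-in targets
        loss = F.cross_entropy(model(x), y)

        # Compute gradient: scale the loss under mixed precision so
        # FP16 gradients survive; amp unscales before optimizer.step().
        if mixed_precision:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # Accumulate gradient for `accumulate` batches before optimizing
        if (i + 1) % accumulate == 0 or (i + 1) == nB:
            optimizer.step()
            optimizer.zero_grad()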