EMA class updates

This commit is contained in:
Glenn Jocher 2020-03-16 17:51:40 -07:00
parent 1a12667ce1
commit 77c6c01970
1 changed files with 2 additions and 2 deletions

View File

@ -152,7 +152,7 @@ class ModelEMA:
msd, esd = model.module.state_dict(), self.ema.module.state_dict() msd, esd = model.module.state_dict(), self.ema.module.state_dict()
else: else:
msd, esd = model.state_dict(), self.ema.state_dict() msd, esd = model.state_dict(), self.ema.state_dict()
# self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
for k, v in esd.items(): for k, v in esd.items():
if v.dtype.is_floating_point: if v.dtype.is_floating_point:
v *= d v *= d
@ -162,4 +162,4 @@ class ModelEMA:
# Assign attributes (which may change during training) # Assign attributes (which may change during training)
for k in model.__dict__.keys(): for k in model.__dict__.keys():
if not k.startswith('_'): if not k.startswith('_'):
setattr(model, k, getattr(model, k)) setattr(self.ema, k, getattr(model, k))