EMA class updates
This commit is contained in:
parent
1a12667ce1
commit
77c6c01970
|
@ -152,7 +152,7 @@ class ModelEMA:
|
||||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||||
else:
|
else:
|
||||||
msd, esd = model.state_dict(), self.ema.state_dict()
|
msd, esd = model.state_dict(), self.ema.state_dict()
|
||||||
# self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
|
|
||||||
for k, v in esd.items():
|
for k, v in esd.items():
|
||||||
if v.dtype.is_floating_point:
|
if v.dtype.is_floating_point:
|
||||||
v *= d
|
v *= d
|
||||||
|
@ -162,4 +162,4 @@ class ModelEMA:
|
||||||
# Assign attributes (which may change during training)
|
# Assign attributes (which may change during training)
|
||||||
for k in model.__dict__.keys():
|
for k in model.__dict__.keys():
|
||||||
if not k.startswith('_'):
|
if not k.startswith('_'):
|
||||||
setattr(model, k, getattr(model, k))
|
setattr(self.ema, k, getattr(model, k))
|
||||||
|
|
Loading…
Reference in New Issue