diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 187d5142..ac38249c 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -152,7 +152,7 @@ class ModelEMA: msd, esd = model.module.state_dict(), self.ema.module.state_dict() else: msd, esd = model.state_dict(), self.ema.state_dict() - # self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point}) + for k, v in esd.items(): if v.dtype.is_floating_point: v *= d @@ -162,4 +162,4 @@ class ModelEMA: # Assign attributes (which may change during training) for k in model.__dict__.keys(): if not k.startswith('_'): - setattr(model, k, getattr(model, k)) + setattr(self.ema, k, getattr(model, k))