diff --git a/train.py b/train.py
index e2ba4b0b..18d2d0f1 100644
--- a/train.py
+++ b/train.py
@@ -217,7 +217,6 @@ def train():
         if prebias:
             ne = 3  # number of prebias epochs
             ps = 0.1, 0.9  # prebias settings (lr=0.1, momentum=0.9)
-            model.gr = 0.0  # giou loss ratio (obj_loss = 1.0)
             if epoch == ne:
                 ps = hyp['lr0'], hyp['momentum']  # normal training settings
                 model.gr = 1.0  # giou loss ratio (obj_loss = giou)
@@ -307,6 +306,7 @@ def train():
         scheduler.step()
 
         # Process epoch results
+        # ema.update_attr(model)
         final_epoch = epoch + 1 == epochs
         if not opt.notest or final_epoch:  # Calculate mAP
             is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
@@ -348,8 +348,7 @@ def train():
             chkpt = {'epoch': epoch,
                      'best_fitness': best_fitness,
                      'training_results': f.read(),
-                     'model': model.module.state_dict() if type(
-                         model) is nn.parallel.DistributedDataParallel else model.state_dict(),
+                     'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
                      'optimizer': None if final_epoch else optimizer.state_dict()}
 
             # Save last checkpoint
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index c706f7f5..b394b642 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -141,10 +141,14 @@ class ModelEMA:
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
             else:
                 msd, esd = model.state_dict(), self.ema.state_dict()
 
-            # self.ema.load_state_dict(
-            #     {k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
-            for k in msd.keys():
-                if esd[k].dtype.is_floating_point:
-                    esd[k] *= d
-                    esd[k] += (1. - d) * msd[k].detach()
+            # self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
+            for k, v in esd.items():
+                if v.dtype.is_floating_point:
+                    v *= d
+                    v += (1. - d) * msd[k].detach()
+    def update_attr(self, model):
+        # Assign attributes (which may change during training)
+        for k in model.__dict__.keys():
+            if not k.startswith('_'):
+                setattr(self.ema, k, getattr(model, k))
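
Note on intended usage (a minimal sketch, not part of the patch): ModelEMA.update() blends the current weights into the EMA copy as esd[k] = d * esd[k] + (1 - d) * msd[k], and the new update_attr() copies non-private model attributes (e.g. nc, gr, which train() reads above) onto self.ema. Assuming an instance named ema wrapping model, the expected call pattern would be roughly:

    ema = ModelEMA(model)          # after building the model
    for imgs, targets in dataloader:
        loss.backward()
        optimizer.step()
        ema.update(model)          # blend current weights into the EMA copy
    ema.update_attr(model)         # sync attributes onto ema.ema before evaluation
    # evaluate/checkpoint with the smoothed weights in ema.ema

The per-epoch call in train.py is left commented out in this patch (# ema.update_attr(model)), so the snippet above only illustrates how the method is meant to be wired in.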