diff --git a/train.py b/train.py
index 7512730e..67fc65fd 100644
--- a/train.py
+++ b/train.py
@@ -197,7 +197,7 @@ def train():
         model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
 
     # Model EMA
-    # ema = torch_utils.ModelEMA(model, decay=0.9998)
+    ema = torch_utils.ModelEMA(model)
 
     # Start training
     nb = len(dataloader)  # number of batches
@@ -291,7 +291,7 @@ def train():
             if ni % accumulate == 0:
                 optimizer.step()
                 optimizer.zero_grad()
-                # ema.update(model)
+                ema.update(model)
 
             # Print batch results
             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
@@ -305,7 +305,7 @@ def train():
         scheduler.step()
 
         # Process epoch results
-        # ema.update_attr(model)
+        ema.update_attr(model)
         final_epoch = epoch + 1 == epochs
         if not opt.notest or final_epoch:  # Calculate mAP
             is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
@@ -313,7 +313,7 @@ def train():
                                       data,
                                       batch_size=batch_size * 2,
                                       img_size=img_size_test,
-                                      model=model,
+                                      model=ema.ema,
                                       conf_thres=0.001 if final_epoch else 0.01,  # 0.001 for best mAP, 0.01 for speed
                                       iou_thres=0.6,
                                       save_json=final_epoch and is_coco,
@@ -347,7 +347,7 @@ def train():
                chkpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': f.read(),
-                        'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
+                        'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(),
                         'optimizer': None if final_epoch else optimizer.state_dict()}
 
             # Save last checkpoint
@@ -377,7 +377,7 @@ def train():
     if opt.bucket:  # save to cloud
         os.system('gsutil cp %s gs://%s/results' % (fresults, opt.bucket))
         os.system('gsutil cp %s gs://%s/weights' % (wdir + flast, opt.bucket))
-        # os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket))
+        os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket))
 
     if not opt.evolve:
         plot_results()  # save as results.png
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 0b7013b0..56e5bba6 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -1,3 +1,4 @@
+import math
 import os
 import time
 from copy import deepcopy
@@ -139,11 +140,12 @@ class ModelEMA:
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
 
-    def __init__(self, model, decay=0.9998, device=''):
+    def __init__(self, model, decay=0.9999, device=''):
         # make a copy of the model for accumulating moving average of weights
         self.ema = deepcopy(model)
         self.ema.eval()
-        self.decay = decay
+        self.updates = 0  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / 1000))  # decay exponential ramp (to help early epochs)
         self.device = device  # perform ema on different device from model if set
         if device:
             self.ema.to(device=device)
@@ -151,7 +153,8 @@ class ModelEMA:
             p.requires_grad_(False)
 
     def update(self, model):
-        d = self.decay
+        self.updates += 1
+        d = self.decay(self.updates)
         with torch.no_grad():
             if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
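
Note on the new decay schedule: instead of a fixed `decay=0.9998`, `ModelEMA` now ramps the effective decay from ~0 toward the nominal value via `decay * (1 - exp(-updates / 1000))`, so the EMA tracks the raw weights closely during early training and only later averages over a long horizon. A minimal standalone sketch of how that ramp behaves, assuming the new default `decay=0.9999` (`effective_decay` is an illustrative helper, not part of the diff):

```python
import math

def effective_decay(updates, decay=0.9999):
    # Mirrors the lambda in ModelEMA.__init__: ramps from 0 toward `decay`.
    return decay * (1 - math.exp(-updates / 1000))

for n in (1, 100, 1000, 5000, 10000):
    print(f'{n:>6d} updates -> effective decay {effective_decay(n):.4f}')
# Roughly: 1 -> 0.0010, 100 -> 0.0952, 1000 -> 0.6321, 5000 -> 0.9932, 10000 -> 0.9999
```

This is also why testing and checkpointing now read from `ema.ema` rather than `model`: the averaged weights are the ones evaluated and saved.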