EMA implemented by default

This commit is contained in:
Glenn Jocher 2020-03-29 13:14:54 -07:00
parent dc8e56b9f3
commit 9c5e76b93d
2 changed files with 12 additions and 9 deletions

View File

@ -197,7 +197,7 @@ def train():
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
# Model EMA # Model EMA
# ema = torch_utils.ModelEMA(model, decay=0.9998) ema = torch_utils.ModelEMA(model)
# Start training # Start training
nb = len(dataloader) # number of batches nb = len(dataloader) # number of batches
@ -291,7 +291,7 @@ def train():
if ni % accumulate == 0: if ni % accumulate == 0:
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
# ema.update(model) ema.update(model)
# Print batch results # Print batch results
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
@ -305,7 +305,7 @@ def train():
scheduler.step() scheduler.step()
# Process epoch results # Process epoch results
# ema.update_attr(model) ema.update_attr(model)
final_epoch = epoch + 1 == epochs final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP if not opt.notest or final_epoch: # Calculate mAP
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80 is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
@ -313,7 +313,7 @@ def train():
data, data,
batch_size=batch_size * 2, batch_size=batch_size * 2,
img_size=img_size_test, img_size=img_size_test,
model=model, model=ema.ema,
conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed
iou_thres=0.6, iou_thres=0.6,
save_json=final_epoch and is_coco, save_json=final_epoch and is_coco,
@ -347,7 +347,7 @@ def train():
chkpt = {'epoch': epoch, chkpt = {'epoch': epoch,
'best_fitness': best_fitness, 'best_fitness': best_fitness,
'training_results': f.read(), 'training_results': f.read(),
'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(), 'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(),
'optimizer': None if final_epoch else optimizer.state_dict()} 'optimizer': None if final_epoch else optimizer.state_dict()}
# Save last checkpoint # Save last checkpoint
@ -377,7 +377,7 @@ def train():
if opt.bucket: # save to cloud if opt.bucket: # save to cloud
os.system('gsutil cp %s gs://%s/results' % (fresults, opt.bucket)) os.system('gsutil cp %s gs://%s/results' % (fresults, opt.bucket))
os.system('gsutil cp %s gs://%s/weights' % (wdir + flast, opt.bucket)) os.system('gsutil cp %s gs://%s/weights' % (wdir + flast, opt.bucket))
# os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket)) os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket))
if not opt.evolve: if not opt.evolve:
plot_results() # save as results.png plot_results() # save as results.png

View File

@ -1,3 +1,4 @@
import math
import os import os
import time import time
from copy import deepcopy from copy import deepcopy
@ -139,11 +140,12 @@ class ModelEMA:
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
""" """
def __init__(self, model, decay=0.9998, device=''): def __init__(self, model, decay=0.9999, device=''):
# make a copy of the model for accumulating moving average of weights # make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model) self.ema = deepcopy(model)
self.ema.eval() self.ema.eval()
self.decay = decay self.updates = 0 # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / 1000)) # decay exponential ramp (to help early epochs)
self.device = device # perform ema on different device from model if set self.device = device # perform ema on different device from model if set
if device: if device:
self.ema.to(device=device) self.ema.to(device=device)
@ -151,7 +153,8 @@ class ModelEMA:
p.requires_grad_(False) p.requires_grad_(False)
def update(self, model): def update(self, model):
d = self.decay self.updates += 1
d = self.decay(self.updates)
with torch.no_grad(): with torch.no_grad():
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel): if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
msd, esd = model.module.state_dict(), self.ema.module.state_dict() msd, esd = model.module.state_dict(), self.ema.module.state_dict()