EMA implemented by default
This commit is contained in:
parent
dc8e56b9f3
commit
9c5e76b93d
12
train.py
12
train.py
|
@ -197,7 +197,7 @@ def train():
|
|||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||
|
||||
# Model EMA
|
||||
# ema = torch_utils.ModelEMA(model, decay=0.9998)
|
||||
ema = torch_utils.ModelEMA(model)
|
||||
|
||||
# Start training
|
||||
nb = len(dataloader) # number of batches
|
||||
|
@ -291,7 +291,7 @@ def train():
|
|||
if ni % accumulate == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# ema.update(model)
|
||||
ema.update(model)
|
||||
|
||||
# Print batch results
|
||||
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
||||
|
@ -305,7 +305,7 @@ def train():
|
|||
scheduler.step()
|
||||
|
||||
# Process epoch results
|
||||
# ema.update_attr(model)
|
||||
ema.update_attr(model)
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if not opt.notest or final_epoch: # Calculate mAP
|
||||
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
|
||||
|
@ -313,7 +313,7 @@ def train():
|
|||
data,
|
||||
batch_size=batch_size * 2,
|
||||
img_size=img_size_test,
|
||||
model=model,
|
||||
model=ema.ema,
|
||||
conf_thres=0.001 if final_epoch else 0.01, # 0.001 for best mAP, 0.01 for speed
|
||||
iou_thres=0.6,
|
||||
save_json=final_epoch and is_coco,
|
||||
|
@ -347,7 +347,7 @@ def train():
|
|||
chkpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'training_results': f.read(),
|
||||
'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
|
||||
'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(),
|
||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||
|
||||
# Save last checkpoint
|
||||
|
@ -377,7 +377,7 @@ def train():
|
|||
if opt.bucket: # save to cloud
|
||||
os.system('gsutil cp %s gs://%s/results' % (fresults, opt.bucket))
|
||||
os.system('gsutil cp %s gs://%s/weights' % (wdir + flast, opt.bucket))
|
||||
# os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket))
|
||||
os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket))
|
||||
|
||||
if not opt.evolve:
|
||||
plot_results() # save as results.png
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
@ -139,11 +140,12 @@ class ModelEMA:
|
|||
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9998, device=''):
|
||||
def __init__(self, model, decay=0.9999, device=''):
|
||||
# make a copy of the model for accumulating moving average of weights
|
||||
self.ema = deepcopy(model)
|
||||
self.ema.eval()
|
||||
self.decay = decay
|
||||
self.updates = 0 # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / 1000)) # decay exponential ramp (to help early epochs)
|
||||
self.device = device # perform ema on different device from model if set
|
||||
if device:
|
||||
self.ema.to(device=device)
|
||||
|
@ -151,7 +153,8 @@ class ModelEMA:
|
|||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
d = self.decay
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
with torch.no_grad():
|
||||
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||
|
|
Loading…
Reference in New Issue