EMA implemented by default

parent dc8e56b9f3
commit 9c5e76b93d

train.py: 12 changed lines
train.py
@@ -197,7 +197,7 @@ def train():
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights

     # Model EMA
-    # ema = torch_utils.ModelEMA(model, decay=0.9998)
+    ema = torch_utils.ModelEMA(model)

     # Start training
     nb = len(dataloader)  # number of batches
@@ -291,7 +291,7 @@ def train():
             if ni % accumulate == 0:
                 optimizer.step()
                 optimizer.zero_grad()
-                # ema.update(model)
+                ema.update(model)

             # Print batch results
             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
@@ -305,7 +305,7 @@ def train():
         scheduler.step()

         # Process epoch results
-        # ema.update_attr(model)
+        ema.update_attr(model)
         final_epoch = epoch + 1 == epochs
         if not opt.notest or final_epoch:  # Calculate mAP
             is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
@@ -313,7 +313,7 @@ def train():
                                       data,
                                       batch_size=batch_size * 2,
                                       img_size=img_size_test,
-                                      model=model,
+                                      model=ema.ema,
                                       conf_thres=0.001 if final_epoch else 0.01,  # 0.001 for best mAP, 0.01 for speed
                                       iou_thres=0.6,
                                       save_json=final_epoch and is_coco,
@@ -347,7 +347,7 @@ def train():
             chkpt = {'epoch': epoch,
                      'best_fitness': best_fitness,
                      'training_results': f.read(),
-                     'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
+                     'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(),
                      'optimizer': None if final_epoch else optimizer.state_dict()}

             # Save last checkpoint
@@ -377,7 +377,7 @@ def train():
     if opt.bucket:  # save to cloud
         os.system('gsutil cp %s gs://%s/results' % (fresults, opt.bucket))
         os.system('gsutil cp %s gs://%s/weights' % (wdir + flast, opt.bucket))
-        # os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket))
+        os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket))

     if not opt.evolve:
         plot_results()  # save as results.png
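For reference, a minimal self-contained sketch of the training-loop pattern the train.py changes above switch on: build the EMA once, refresh it after every optimizer step, and evaluate and checkpoint the smoothed copy instead of the raw model. EMASketch, the toy model, and the random data below are illustrative stand-ins, not the project's torch_utils.ModelEMA.

import math
from copy import deepcopy

import torch
import torch.nn as nn


class EMASketch:
    # Stand-in for torch_utils.ModelEMA, reduced to the interface used above.
    def __init__(self, model, decay=0.9999):
        self.ema = deepcopy(model).eval()  # shadow copy, never trained directly
        self.updates = 0
        self.decay = lambda x: decay * (1 - math.exp(-x / 1000))  # same ramp as in the diff
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        self.updates += 1
        d = self.decay(self.updates)
        with torch.no_grad():
            msd, esd = model.state_dict(), self.ema.state_dict()
            for k, v in esd.items():
                if v.dtype.is_floating_point:
                    v *= d                               # keep d of the old average
                    v += (1.0 - d) * msd[k].detach()     # blend in (1 - d) of the live weights


model = nn.Linear(4, 2)                                  # toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ema = EMASketch(model)                                   # mirrors: ema = torch_utils.ModelEMA(model)

for step in range(100):
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model)                                    # mirrors: ema.update(model) each optimizer step

# Evaluate and checkpoint the smoothed copy, mirroring model=ema.ema and ema.ema.state_dict() above.
print(ema.ema(torch.randn(1, 4)))
# torch.save({'model': ema.ema.state_dict()}, 'last_ema.pt')

The real ModelEMA additionally handles nn.DataParallel / DistributedDataParallel wrappers (see the update() hunk below) and copies non-parameter attributes with update_attr(); the sketch leaves that out.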
torch_utils.py
@@ -1,3 +1,4 @@
+import math
 import os
 import time
 from copy import deepcopy
@@ -139,11 +140,12 @@ class ModelEMA:
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """

-    def __init__(self, model, decay=0.9998, device=''):
+    def __init__(self, model, decay=0.9999, device=''):
         # make a copy of the model for accumulating moving average of weights
         self.ema = deepcopy(model)
         self.ema.eval()
-        self.decay = decay
+        self.updates = 0  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / 1000))  # decay exponential ramp (to help early epochs)
         self.device = device  # perform ema on different device from model if set
         if device:
             self.ema.to(device=device)
@@ -151,7 +153,8 @@ class ModelEMA:
             p.requires_grad_(False)

     def update(self, model):
-        d = self.decay
+        self.updates += 1
+        d = self.decay(self.updates)
         with torch.no_grad():
             if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
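The fixed self.decay constant becomes a warm-up ramp: the effective decay starts near zero and approaches the nominal value as the update counter grows, so the EMA tracks the fast-moving weights of early training before settling into a long-run average (the "to help early epochs" comment). A quick standalone check of that ramp, using the new default decay=0.9999; effective_decay is just an illustrative name:

import math

decay = 0.9999  # new default from __init__

def effective_decay(updates):
    # same expression as the lambda added in __init__
    return decay * (1 - math.exp(-updates / 1000))

for n in (1, 10, 100, 1000, 5000, 10000):
    print(f"after {n:>6} updates: d = {effective_decay(n):.4f}")

While d is near zero the shadow weights are essentially a copy of the live model; once d approaches 0.9999, each update blends in only about 0.01% of the live weights.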