This commit is contained in:
Glenn Jocher 2020-03-13 20:12:54 -07:00
parent 418269d739
commit a52c0abf8d
2 changed files with 56 additions and 3 deletions

View File

@ -190,13 +190,19 @@ def train():
pin_memory=True, pin_memory=True,
collate_fn=dataset.collate_fn) collate_fn=dataset.collate_fn)
# Start training # Model parameters
nb = len(dataloader) # number of batches
prebias = start_epoch == 0
model.nc = nc # attach number of classes to model model.nc = nc # attach number of classes to model
model.arc = opt.arc # attach yolo architecture model.arc = opt.arc # attach yolo architecture
model.hyp = hyp # attach hyperparameters to model model.hyp = hyp # attach hyperparameters to model
model.gr = 0.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
# Model EMA
# ema = torch_utils.ModelEMA(model, decay=0.9997)
# Start training
nb = len(dataloader) # number of batches
prebias = start_epoch == 0
maps = np.zeros(nc) # mAP per class maps = np.zeros(nc) # mAP per class
# torch.autograd.set_detect_anomaly(True) # torch.autograd.set_detect_anomaly(True)
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification' results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'

View File

@ -1,8 +1,10 @@
import os import os
import time import time
from copy import deepcopy
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.nn as nn
def init_seeds(seed=0): def init_seeds(seed=0):
@ -101,3 +103,48 @@ def load_classifier(name='resnet101', n=2):
model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters)) model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
model.last_linear.out_features = n model.last_linear.out_features = n
return model return model
class ModelEMA:
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
smoothing of weights to match results. Pay attention to the decay constant you are using
relative to your update count per epoch.
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
"""
def __init__(self, model, decay=0.9998, device=''):
# make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if device:
self.ema.to(device=device)
for p in self.ema.parameters():
p.requires_grad_(False)
def update(self, model):
d = self.decay
with torch.no_grad():
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
else:
msd, esd = model.state_dict(), self.ema.state_dict()
# self.ema.load_state_dict(
# {k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
for k in msd.keys():
if esd[k].dtype.is_floating_point:
esd[k] *= d
esd[k] += (1. - d) * msd[k].detach()