updates
This commit is contained in:
parent
418269d739
commit
a52c0abf8d
12
train.py
12
train.py
|
@ -190,13 +190,19 @@ def train():
|
|||
pin_memory=True,
|
||||
collate_fn=dataset.collate_fn)
|
||||
|
||||
# Start training
|
||||
nb = len(dataloader) # number of batches
|
||||
prebias = start_epoch == 0
|
||||
# Model parameters
|
||||
model.nc = nc # attach number of classes to model
|
||||
model.arc = opt.arc # attach yolo architecture
|
||||
model.hyp = hyp # attach hyperparameters to model
|
||||
model.gr = 0.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||
|
||||
# Model EMA
|
||||
# ema = torch_utils.ModelEMA(model, decay=0.9997)
|
||||
|
||||
# Start training
|
||||
nb = len(dataloader) # number of batches
|
||||
prebias = start_epoch == 0
|
||||
maps = np.zeros(nc) # mAP per class
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
|
@ -101,3 +103,48 @@ def load_classifier(name='resnet101', n=2):
|
|||
model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
|
||||
model.last_linear.out_features = n
|
||||
return model
|
||||
|
||||
|
||||
class ModelEMA:
|
||||
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
|
||||
Keep a moving average of everything in the model state_dict (parameters and buffers).
|
||||
This is intended to allow functionality like
|
||||
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
A smoothed version of the weights is necessary for some training schemes to perform well.
|
||||
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
|
||||
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
|
||||
smoothing of weights to match results. Pay attention to the decay constant you are using
|
||||
relative to your update count per epoch.
|
||||
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
|
||||
disable validation of the EMA weights. Validation will have to be done manually in a separate
|
||||
process, or after the training stops converging.
|
||||
This class is sensitive where it is initialized in the sequence of model init,
|
||||
GPU assignment and distributed training wrappers.
|
||||
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9998, device=''):
|
||||
# make a copy of the model for accumulating moving average of weights
|
||||
self.ema = deepcopy(model)
|
||||
self.ema.eval()
|
||||
self.decay = decay
|
||||
self.device = device # perform ema on different device from model if set
|
||||
if device:
|
||||
self.ema.to(device=device)
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
d = self.decay
|
||||
with torch.no_grad():
|
||||
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||
else:
|
||||
msd, esd = model.state_dict(), self.ema.state_dict()
|
||||
|
||||
# self.ema.load_state_dict(
|
||||
# {k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
|
||||
for k in msd.keys():
|
||||
if esd[k].dtype.is_floating_point:
|
||||
esd[k] *= d
|
||||
esd[k] += (1. - d) * msd[k].detach()
|
||||
|
|
Loading…
Reference in New Issue