updates

This commit is contained in:
parent 418269d739
commit a52c0abf8d

Changed files:
train.py (12)
utils/torch_utils.py
train.py
@@ -190,13 +190,19 @@ def train():
                                              pin_memory=True,
                                              collate_fn=dataset.collate_fn)

-    # Start training
-    nb = len(dataloader)  # number of batches
-    prebias = start_epoch == 0
+    # Model parameters
     model.nc = nc  # attach number of classes to model
     model.arc = opt.arc  # attach yolo architecture
     model.hyp = hyp  # attach hyperparameters to model
+    model.gr = 0.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights

+    # Model EMA
+    # ema = torch_utils.ModelEMA(model, decay=0.9997)
+
+    # Start training
+    nb = len(dataloader)  # number of batches
+    prebias = start_epoch == 0
     maps = np.zeros(nc)  # mAP per class
     # torch.autograd.set_detect_anomaly(True)
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
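Note: the commit adds ModelEMA but leaves it disabled in train.py (the ema line stays commented out). For context, a minimal sketch of how it could be wired into a training loop; compute_loss and test are stand-ins for the project's own helpers, not names this commit defines:

    ema = torch_utils.ModelEMA(model, decay=0.9997)  # create after model init and GPU assignment

    for epoch in range(start_epoch, epochs):
        for imgs, targets, paths, _ in dataloader:
            loss = compute_loss(model(imgs.to(device)), targets)  # stand-in loss helper
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            ema.update(model)  # fold the freshly stepped weights into the moving average

    results = test(model=ema.ema)  # evaluate and save the smoothed copy, not the live model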
utils/torch_utils.py
@@ -1,8 +1,10 @@
 import os
 import time
+from copy import deepcopy

 import torch
 import torch.backends.cudnn as cudnn
+import torch.nn as nn


 def init_seeds(seed=0):
@@ -101,3 +103,48 @@ def load_classifier(name='resnet101', n=2):
     model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
     model.last_linear.out_features = n
     return model
+
+
+class ModelEMA:
+    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc. that use
+    RMSprop with a short 2.4-3 epoch decay period and a slow LR decay rate of .96-.99 require EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after training stops converging.
+    This class is sensitive to where it is initialized in the sequence of model init,
+    GPU assignment, and distributed training wrappers.
+    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
+    """
+
+    def __init__(self, model, decay=0.9998, device=''):
+        # make a copy of the model for accumulating the moving average of weights
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device  # perform ema on a different device from the model if set
+        if device:
+            self.ema.to(device=device)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        d = self.decay
+        with torch.no_grad():
+            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+                msd, esd = model.module.state_dict(), self.ema.module.state_dict()
+            else:
+                msd, esd = model.state_dict(), self.ema.state_dict()
+
+            # self.ema.load_state_dict(
+            #     {k: esd[k] * d + (1. - d) * v.detach() for k, v in msd.items() if v.dtype.is_floating_point})
+            for k in msd.keys():
+                if esd[k].dtype.is_floating_point:
+                    esd[k] *= d
+                    esd[k] += (1. - d) * msd[k].detach()
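The in-place loop at the end of update() is the standard EMA recurrence w_ema <- d * w_ema + (1 - d) * w. A tiny runnable illustration on a bare tensor (toy values, not part of the commit):

    import torch

    d = 0.9998                              # ModelEMA's default decay
    w_ema = torch.zeros(3)                  # smoothed copy of one parameter
    for step in range(1, 4):
        w = torch.full((3,), float(step))   # pretend training moved the live weight
        w_ema.mul_(d).add_((1. - d) * w)    # same arithmetic as esd[k] *= d; esd[k] += (1. - d) * msd[k]
    print(w_ema)                            # drifts toward the live weights at a rate of (1 - d) per step

With d = 0.9998 the smoothed copy absorbs only 0.02% of each update, so it takes thousands of steps to track the live model; this is the "decay constant relative to your update count per epoch" caveat from the docstring.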