updates
This commit is contained in:
		
							parent
							
								
									d20131b2e8
								
							
						
					
					
						commit
						91eaf2f8fe
					
				
							
								
								
									
										16
									
								
								train.py
								
								
								
								
							
							
						
						
									
										16
									
								
								train.py
								
								
								
								
							|  | @ -10,6 +10,7 @@ import test  # import test.py to get mAP after each epoch | |||
| from models import * | ||||
| from utils.datasets import * | ||||
| from utils.utils import * | ||||
| from utils.adabound import * | ||||
| 
 | ||||
| # 320 --epochs 1 | ||||
| #      0.109      0.297       0.15      0.126       7.04      1.666      4.062     0.1845       42.6       3.34      12.61      8.338     0.2705      0.001         -4        0.9     0.0005 a  320 giou + best_anchor False | ||||
|  | @ -89,7 +90,8 @@ def train(cfg, | |||
|     model = Darknet(cfg).to(device) | ||||
| 
 | ||||
|     # Optimizer | ||||
|     optimizer = optim.SGD(model.parameters(), lr=hyp['lr0'], momentum=hyp['momentum'], weight_decay=hyp['weight_decay']) | ||||
|     optimizer = optim.SGD(model.parameters(), lr=hyp['lr0'], momentum=hyp['momentum'], weight_decay=hyp['weight_decay'], nesterov=True) | ||||
|     # optimizer = AdaBound(model.parameters(), lr=hyp['lr0'], final_lr=0.1) | ||||
| 
 | ||||
|     cutoff = -1  # backbone reaches to cutoff layer | ||||
|     start_epoch = 0 | ||||
|  | @ -192,7 +194,7 @@ def train(cfg, | |||
|     nb = len(dataloader) | ||||
|     maps = np.zeros(nc)  # mAP per class | ||||
|     results = (0, 0, 0, 0, 0)  # P, R, mAP, F1, test_loss | ||||
|     n_burnin = min(round(nb / 5 + 1), 1000)  # burn-in batches | ||||
|     # n_burnin = min(round(nb / 5 + 1), 1000)  # burn-in batches | ||||
|     t0 = time.time() | ||||
|     for epoch in range(start_epoch, epochs): | ||||
|         model.train() | ||||
|  | @ -234,11 +236,11 @@ def train(cfg, | |||
|                 plot_images(imgs=imgs, targets=targets, paths=paths, fname='train_batch%g.jpg' % i) | ||||
| 
 | ||||
|             # SGD burn-in | ||||
|             if epoch == 0 and i <= n_burnin: | ||||
|                 g = (i / n_burnin) ** 4  # gain | ||||
|                 for x in optimizer.param_groups: | ||||
|                     x['lr'] = hyp['lr0'] * g | ||||
|                     x['weight_decay'] = hyp['weight_decay'] * g | ||||
|             # if epoch == 0 and i <= n_burnin: | ||||
|             #     g = (i / n_burnin) ** 4  # gain | ||||
|             #     for x in optimizer.param_groups: | ||||
|             #         x['lr'] = hyp['lr0'] * g | ||||
|             #         x['weight_decay'] = hyp['weight_decay'] * g | ||||
| 
 | ||||
|             # Run model | ||||
|             pred = model(imgs) | ||||
|  |  | |||
|  | @ -0,0 +1,235 @@ | |||
| import math | ||||
| import torch | ||||
| from torch.optim import Optimizer | ||||
| 
 | ||||
| 
 | ||||
| class AdaBound(Optimizer): | ||||
|     """Implements AdaBound algorithm. | ||||
|     It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. | ||||
|     Arguments: | ||||
|         params (iterable): iterable of parameters to optimize or dicts defining | ||||
|             parameter groups | ||||
|         lr (float, optional): Adam learning rate (default: 1e-3) | ||||
|         betas (Tuple[float, float], optional): coefficients used for computing | ||||
|             running averages of gradient and its square (default: (0.9, 0.999)) | ||||
|         final_lr (float, optional): final (SGD) learning rate (default: 0.1) | ||||
|         gamma (float, optional): convergence speed of the bound functions (default: 1e-3) | ||||
|         eps (float, optional): term added to the denominator to improve | ||||
|             numerical stability (default: 1e-8) | ||||
|         weight_decay (float, optional): weight decay (L2 penalty) (default: 0) | ||||
|         amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm | ||||
|     .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: | ||||
|         https://openreview.net/forum?id=Bkg3g2R9FX | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, | ||||
|                  eps=1e-8, weight_decay=0, amsbound=False): | ||||
|         if not 0.0 <= lr: | ||||
|             raise ValueError("Invalid learning rate: {}".format(lr)) | ||||
|         if not 0.0 <= eps: | ||||
|             raise ValueError("Invalid epsilon value: {}".format(eps)) | ||||
|         if not 0.0 <= betas[0] < 1.0: | ||||
|             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | ||||
|         if not 0.0 <= betas[1] < 1.0: | ||||
|             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | ||||
|         if not 0.0 <= final_lr: | ||||
|             raise ValueError("Invalid final learning rate: {}".format(final_lr)) | ||||
|         if not 0.0 <= gamma < 1.0: | ||||
|             raise ValueError("Invalid gamma parameter: {}".format(gamma)) | ||||
|         defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, | ||||
|                         weight_decay=weight_decay, amsbound=amsbound) | ||||
|         super(AdaBound, self).__init__(params, defaults) | ||||
| 
 | ||||
|         self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) | ||||
| 
 | ||||
|     def __setstate__(self, state): | ||||
|         super(AdaBound, self).__setstate__(state) | ||||
|         for group in self.param_groups: | ||||
|             group.setdefault('amsbound', False) | ||||
| 
 | ||||
|     def step(self, closure=None): | ||||
|         """Performs a single optimization step. | ||||
|         Arguments: | ||||
|             closure (callable, optional): A closure that reevaluates the model | ||||
|                 and returns the loss. | ||||
|         """ | ||||
|         loss = None | ||||
|         if closure is not None: | ||||
|             loss = closure() | ||||
| 
 | ||||
|         for group, base_lr in zip(self.param_groups, self.base_lrs): | ||||
|             for p in group['params']: | ||||
|                 if p.grad is None: | ||||
|                     continue | ||||
|                 grad = p.grad.data | ||||
|                 if grad.is_sparse: | ||||
|                     raise RuntimeError( | ||||
|                         'Adam does not support sparse gradients, please consider SparseAdam instead') | ||||
|                 amsbound = group['amsbound'] | ||||
| 
 | ||||
|                 state = self.state[p] | ||||
| 
 | ||||
|                 # State initialization | ||||
|                 if len(state) == 0: | ||||
|                     state['step'] = 0 | ||||
|                     # Exponential moving average of gradient values | ||||
|                     state['exp_avg'] = torch.zeros_like(p.data) | ||||
|                     # Exponential moving average of squared gradient values | ||||
|                     state['exp_avg_sq'] = torch.zeros_like(p.data) | ||||
|                     if amsbound: | ||||
|                         # Maintains max of all exp. moving avg. of sq. grad. values | ||||
|                         state['max_exp_avg_sq'] = torch.zeros_like(p.data) | ||||
| 
 | ||||
|                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] | ||||
|                 if amsbound: | ||||
|                     max_exp_avg_sq = state['max_exp_avg_sq'] | ||||
|                 beta1, beta2 = group['betas'] | ||||
| 
 | ||||
|                 state['step'] += 1 | ||||
| 
 | ||||
|                 if group['weight_decay'] != 0: | ||||
|                     grad = grad.add(group['weight_decay'], p.data) | ||||
| 
 | ||||
|                 # Decay the first and second moment running average coefficient | ||||
|                 exp_avg.mul_(beta1).add_(1 - beta1, grad) | ||||
|                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) | ||||
|                 if amsbound: | ||||
|                     # Maintains the maximum of all 2nd moment running avg. till now | ||||
|                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) | ||||
|                     # Use the max. for normalizing running avg. of gradient | ||||
|                     denom = max_exp_avg_sq.sqrt().add_(group['eps']) | ||||
|                 else: | ||||
|                     denom = exp_avg_sq.sqrt().add_(group['eps']) | ||||
| 
 | ||||
|                 bias_correction1 = 1 - beta1 ** state['step'] | ||||
|                 bias_correction2 = 1 - beta2 ** state['step'] | ||||
|                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 | ||||
| 
 | ||||
|                 # Applies bounds on actual learning rate | ||||
|                 # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay | ||||
|                 final_lr = group['final_lr'] * group['lr'] / base_lr | ||||
|                 lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) | ||||
|                 upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) | ||||
|                 step_size = torch.full_like(denom, step_size) | ||||
|                 step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) | ||||
| 
 | ||||
|                 p.data.add_(-step_size) | ||||
| 
 | ||||
|         return loss | ||||
| 
 | ||||
| 
 | ||||
| class AdaBoundW(Optimizer): | ||||
|     """Implements AdaBound algorithm with Decoupled Weight Decay (arxiv.org/abs/1711.05101) | ||||
|     It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. | ||||
|     Arguments: | ||||
|         params (iterable): iterable of parameters to optimize or dicts defining | ||||
|             parameter groups | ||||
|         lr (float, optional): Adam learning rate (default: 1e-3) | ||||
|         betas (Tuple[float, float], optional): coefficients used for computing | ||||
|             running averages of gradient and its square (default: (0.9, 0.999)) | ||||
|         final_lr (float, optional): final (SGD) learning rate (default: 0.1) | ||||
|         gamma (float, optional): convergence speed of the bound functions (default: 1e-3) | ||||
|         eps (float, optional): term added to the denominator to improve | ||||
|             numerical stability (default: 1e-8) | ||||
|         weight_decay (float, optional): weight decay (L2 penalty) (default: 0) | ||||
|         amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm | ||||
|     .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: | ||||
|         https://openreview.net/forum?id=Bkg3g2R9FX | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, | ||||
|                  eps=1e-8, weight_decay=0, amsbound=False): | ||||
|         if not 0.0 <= lr: | ||||
|             raise ValueError("Invalid learning rate: {}".format(lr)) | ||||
|         if not 0.0 <= eps: | ||||
|             raise ValueError("Invalid epsilon value: {}".format(eps)) | ||||
|         if not 0.0 <= betas[0] < 1.0: | ||||
|             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | ||||
|         if not 0.0 <= betas[1] < 1.0: | ||||
|             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | ||||
|         if not 0.0 <= final_lr: | ||||
|             raise ValueError("Invalid final learning rate: {}".format(final_lr)) | ||||
|         if not 0.0 <= gamma < 1.0: | ||||
|             raise ValueError("Invalid gamma parameter: {}".format(gamma)) | ||||
|         defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, | ||||
|                         weight_decay=weight_decay, amsbound=amsbound) | ||||
|         super(AdaBoundW, self).__init__(params, defaults) | ||||
| 
 | ||||
|         self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) | ||||
| 
 | ||||
|     def __setstate__(self, state): | ||||
|         super(AdaBoundW, self).__setstate__(state) | ||||
|         for group in self.param_groups: | ||||
|             group.setdefault('amsbound', False) | ||||
| 
 | ||||
|     def step(self, closure=None): | ||||
|         """Performs a single optimization step. | ||||
|         Arguments: | ||||
|             closure (callable, optional): A closure that reevaluates the model | ||||
|                 and returns the loss. | ||||
|         """ | ||||
|         loss = None | ||||
|         if closure is not None: | ||||
|             loss = closure() | ||||
| 
 | ||||
|         for group, base_lr in zip(self.param_groups, self.base_lrs): | ||||
|             for p in group['params']: | ||||
|                 if p.grad is None: | ||||
|                     continue | ||||
|                 grad = p.grad.data | ||||
|                 if grad.is_sparse: | ||||
|                     raise RuntimeError( | ||||
|                         'Adam does not support sparse gradients, please consider SparseAdam instead') | ||||
|                 amsbound = group['amsbound'] | ||||
| 
 | ||||
|                 state = self.state[p] | ||||
| 
 | ||||
|                 # State initialization | ||||
|                 if len(state) == 0: | ||||
|                     state['step'] = 0 | ||||
|                     # Exponential moving average of gradient values | ||||
|                     state['exp_avg'] = torch.zeros_like(p.data) | ||||
|                     # Exponential moving average of squared gradient values | ||||
|                     state['exp_avg_sq'] = torch.zeros_like(p.data) | ||||
|                     if amsbound: | ||||
|                         # Maintains max of all exp. moving avg. of sq. grad. values | ||||
|                         state['max_exp_avg_sq'] = torch.zeros_like(p.data) | ||||
| 
 | ||||
|                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] | ||||
|                 if amsbound: | ||||
|                     max_exp_avg_sq = state['max_exp_avg_sq'] | ||||
|                 beta1, beta2 = group['betas'] | ||||
| 
 | ||||
|                 state['step'] += 1 | ||||
| 
 | ||||
|                 # Decay the first and second moment running average coefficient | ||||
|                 exp_avg.mul_(beta1).add_(1 - beta1, grad) | ||||
|                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) | ||||
|                 if amsbound: | ||||
|                     # Maintains the maximum of all 2nd moment running avg. till now | ||||
|                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) | ||||
|                     # Use the max. for normalizing running avg. of gradient | ||||
|                     denom = max_exp_avg_sq.sqrt().add_(group['eps']) | ||||
|                 else: | ||||
|                     denom = exp_avg_sq.sqrt().add_(group['eps']) | ||||
| 
 | ||||
|                 bias_correction1 = 1 - beta1 ** state['step'] | ||||
|                 bias_correction2 = 1 - beta2 ** state['step'] | ||||
|                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 | ||||
| 
 | ||||
|                 # Applies bounds on actual learning rate | ||||
|                 # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay | ||||
|                 final_lr = group['final_lr'] * group['lr'] / base_lr | ||||
|                 lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) | ||||
|                 upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) | ||||
|                 step_size = torch.full_like(denom, step_size) | ||||
|                 step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) | ||||
| 
 | ||||
|                 if group['weight_decay'] != 0: | ||||
|                     decayed_weights = torch.mul(p.data, group['weight_decay']) | ||||
|                     p.data.add_(-step_size) | ||||
|                     p.data.sub_(decayed_weights) | ||||
|                 else: | ||||
|                     p.data.add_(-step_size) | ||||
| 
 | ||||
|         return loss | ||||
		Loading…
	
		Reference in New Issue