From e613bbc88c8a4867b9d0e9e24a80ebdcb769966b Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 29 Nov 2019 19:10:01 -0800
Subject: [PATCH] updates

---
 train.py             |  3 ++
 utils/torch_utils.py | 70 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+)

diff --git a/train.py b/train.py
index 0038db97..ad328656 100644
--- a/train.py
+++ b/train.py
@@ -96,6 +96,9 @@ def train():
     optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
     del pg0, pg1
 
+    # https://github.com/alphadl/lookahead.pytorch
+    # optimizer = torch_utils.Lookahead(optimizer, k=5, alpha=0.5)
+
     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
     best_fitness = float('inf')
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index e7e15715..ecbcd306 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -96,3 +96,73 @@ def load_classifier(name='resnet101', n=2):
     model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
     model.last_linear.out_features = n
     return model
+
+
+from collections import defaultdict
+from torch.optim import Optimizer
+
+class Lookahead(Optimizer):
+    def __init__(self, optimizer, k=5, alpha=0.5):
+        self.optimizer = optimizer
+        self.k = k
+        self.alpha = alpha
+        self.param_groups = self.optimizer.param_groups
+        self.state = defaultdict(dict)
+        self.fast_state = self.optimizer.state
+        for group in self.param_groups:
+            group["counter"] = 0
+
+    def update(self, group):
+        for fast in group["params"]:
+            param_state = self.state[fast]
+            if "slow_param" not in param_state:
+                param_state["slow_param"] = torch.zeros_like(fast.data)
+                param_state["slow_param"].copy_(fast.data)
+            slow = param_state["slow_param"]
+            slow += (fast.data - slow) * self.alpha
+            fast.data.copy_(slow)
+
+    def update_lookahead(self):
+        for group in self.param_groups:
+            self.update(group)
+
+    def step(self, closure=None):
+        loss = self.optimizer.step(closure)
+        for group in self.param_groups:
+            if group["counter"] == 0:
+                self.update(group)
+            group["counter"] += 1
+            if group["counter"] >= self.k:
+                group["counter"] = 0
+        return loss
+
+    def state_dict(self):
+        fast_state_dict = self.optimizer.state_dict()
+        slow_state = {
+            (id(k) if isinstance(k, torch.Tensor) else k): v
+            for k, v in self.state.items()
+        }
+        fast_state = fast_state_dict["state"]
+        param_groups = fast_state_dict["param_groups"]
+        return {
+            "fast_state": fast_state,
+            "slow_state": slow_state,
+            "param_groups": param_groups,
+        }
+
+    def load_state_dict(self, state_dict):
+        slow_state_dict = {
+            "state": state_dict["slow_state"],
+            "param_groups": state_dict["param_groups"],
+        }
+        fast_state_dict = {
+            "state": state_dict["fast_state"],
+            "param_groups": state_dict["param_groups"],
+        }
+        super(Lookahead, self).load_state_dict(slow_state_dict)
+        self.optimizer.load_state_dict(fast_state_dict)
+        self.fast_state = self.optimizer.state
+
+    def add_param_group(self, param_group):
+        param_group["counter"] = 0
+        self.optimizer.add_param_group(param_group)
\ No newline at end of file
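
Notes (not part of the patch): the Lookahead class above is the two-loop scheme from "Lookahead Optimizer: k steps forward, 1 step back" (Zhang et al., 2019), as implemented in the linked alphadl/lookahead.pytorch repo. The wrapped "fast" optimizer takes k ordinary steps; every k-th step() pulls the "slow" weights toward the fast weights via slow += alpha * (fast - slow) and copies the result back into the model. Below is a minimal usage sketch, assuming the patched utils/torch_utils.py is importable from the repo root; the toy model, data, and loop length are illustrative only and do not come from the patch.

import torch
import torch.nn as nn
from utils.torch_utils import Lookahead  # class added by this patch

model = nn.Linear(10, 1)  # toy model, illustrative only
base = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # fast (inner) optimizer
optimizer = Lookahead(base, k=5, alpha=0.5)  # mirrors the commented-out line in train.py

x, y = torch.randn(64, 10), torch.randn(64, 1)  # toy data, illustrative only
for _ in range(20):
    base.zero_grad()  # zero grads via the inner optimizer; the wrapper does not override zero_grad
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # fast step every call; slow/fast interpolation fires each time counter wraps to 0

Design note: because step() calls update(group) when counter == 0, the very first step() only initializes the slow weights to the just-updated fast weights (the interpolation is a no-op there); the real sync then fires every k steps. k=5 and alpha=0.5 are the defaults suggested in the paper.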