diff --git a/train.py b/train.py
index 917fbda8..0b459286 100644
--- a/train.py
+++ b/train.py
@@ -99,9 +99,6 @@ def train():
     optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
     del pg0, pg1, pg2
 
-    # https://github.com/alphadl/lookahead.pytorch
-    # optimizer = torch_utils.Lookahead(optimizer, k=5, alpha=0.5)
-
     start_epoch = 0
     best_fitness = 0.0
     attempt_download(weights)
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 9b321724..b984b265 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -94,74 +94,3 @@ def load_classifier(name='resnet101', n=2):
     model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
     model.last_linear.out_features = n
     return model
-
-
-from collections import defaultdict
-from torch.optim import Optimizer
-
-
-class Lookahead(Optimizer):
-    def __init__(self, optimizer, k=5, alpha=0.5):
-        self.optimizer = optimizer
-        self.k = k
-        self.alpha = alpha
-        self.param_groups = self.optimizer.param_groups
-        self.state = defaultdict(dict)
-        self.fast_state = self.optimizer.state
-        for group in self.param_groups:
-            group["counter"] = 0
-
-    def update(self, group):
-        for fast in group["params"]:
-            param_state = self.state[fast]
-            if "slow_param" not in param_state:
-                param_state["slow_param"] = torch.zeros_like(fast.data)
-                param_state["slow_param"].copy_(fast.data)
-            slow = param_state["slow_param"]
-            slow += (fast.data - slow) * self.alpha
-            fast.data.copy_(slow)
-
-    def update_lookahead(self):
-        for group in self.param_groups:
-            self.update(group)
-
-    def step(self, closure=None):
-        loss = self.optimizer.step(closure)
-        for group in self.param_groups:
-            if group["counter"] == 0:
-                self.update(group)
-            group["counter"] += 1
-            if group["counter"] >= self.k:
-                group["counter"] = 0
-        return loss
-
-    def state_dict(self):
-        fast_state_dict = self.optimizer.state_dict()
-        slow_state = {
-            (id(k) if isinstance(k, torch.Tensor) else k): v
-            for k, v in self.state.items()
-        }
-        fast_state = fast_state_dict["state"]
-        param_groups = fast_state_dict["param_groups"]
-        return {
-            "fast_state": fast_state,
-            "slow_state": slow_state,
-            "param_groups": param_groups,
-        }
-
-    def load_state_dict(self, state_dict):
-        slow_state_dict = {
-            "state": state_dict["slow_state"],
-            "param_groups": state_dict["param_groups"],
-        }
-        fast_state_dict = {
-            "state": state_dict["fast_state"],
-            "param_groups": state_dict["param_groups"],
-        }
-        super(Lookahead, self).load_state_dict(slow_state_dict)
-        self.optimizer.load_state_dict(fast_state_dict)
-        self.fast_state = self.optimizer.state
-
-    def add_param_group(self, param_group):
-        param_group["counter"] = 0
-        self.optimizer.add_param_group(param_group)