updates

commit 7608047531
parent 3cf8a13910

train.py (3 changed lines)

@@ -99,9 +99,6 @@ def train():
     optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
     del pg0, pg1, pg2

-    # https://github.com/alphadl/lookahead.pytorch
-    # optimizer = torch_utils.Lookahead(optimizer, k=5, alpha=0.5)
-
     start_epoch = 0
     best_fitness = 0.0
     attempt_download(weights)
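
For context: the kept lines are the tail of the usual ultralytics three-way parameter grouping, and the three removed lines drop a commented-out pointer to the Lookahead wrapper whose definition goes away in the next hunk. Below is a hedged sketch of the grouping idiom the pg0/pg1/pg2 names imply; the stand-in model and the hyperparameter values are assumptions, not code from this commit.

    import torch
    import torch.nn as nn

    # Stand-in model; the real train.py builds a Darknet model here (assumption).
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 4, 1))

    pg0, pg1, pg2 = [], [], []  # BN gammas / decayed weights / biases
    for m in model.modules():
        if hasattr(m, 'bias') and isinstance(m.bias, nn.Parameter):
            pg2.append(m.bias)  # all biases -> pg2, no weight decay
        if isinstance(m, nn.BatchNorm2d):
            pg0.append(m.weight)  # BN gammas -> pg0, no weight decay
        elif hasattr(m, 'weight') and isinstance(m.weight, nn.Parameter):
            pg1.append(m.weight)  # conv weights -> pg1, with weight decay

    hyp = {'lr0': 0.01, 'momentum': 0.937, 'weight_decay': 5e-4}  # assumed values
    optimizer = torch.optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases), as kept above
    del pg0, pg1, pg2  # free the temporary lists, as kept above
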
@@ -94,74 +94,3 @@ def load_classifier(name='resnet101', n=2):
     model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
     model.last_linear.out_features = n
     return model
-
-
-from collections import defaultdict
-from torch.optim import Optimizer
-
-
-class Lookahead(Optimizer):
-    def __init__(self, optimizer, k=5, alpha=0.5):
-        self.optimizer = optimizer
-        self.k = k
-        self.alpha = alpha
-        self.param_groups = self.optimizer.param_groups
-        self.state = defaultdict(dict)
-        self.fast_state = self.optimizer.state
-        for group in self.param_groups:
-            group["counter"] = 0
-
-    def update(self, group):
-        for fast in group["params"]:
-            param_state = self.state[fast]
-            if "slow_param" not in param_state:
-                param_state["slow_param"] = torch.zeros_like(fast.data)
-                param_state["slow_param"].copy_(fast.data)
-            slow = param_state["slow_param"]
-            slow += (fast.data - slow) * self.alpha
-            fast.data.copy_(slow)
-
-    def update_lookahead(self):
-        for group in self.param_groups:
-            self.update(group)
-
-    def step(self, closure=None):
-        loss = self.optimizer.step(closure)
-        for group in self.param_groups:
-            if group["counter"] == 0:
-                self.update(group)
-            group["counter"] += 1
-            if group["counter"] >= self.k:
-                group["counter"] = 0
-        return loss
-
-    def state_dict(self):
-        fast_state_dict = self.optimizer.state_dict()
-        slow_state = {
-            (id(k) if isinstance(k, torch.Tensor) else k): v
-            for k, v in self.state.items()
-        }
-        fast_state = fast_state_dict["state"]
-        param_groups = fast_state_dict["param_groups"]
-        return {
-            "fast_state": fast_state,
-            "slow_state": slow_state,
-            "param_groups": param_groups,
-        }
-
-    def load_state_dict(self, state_dict):
-        slow_state_dict = {
-            "state": state_dict["slow_state"],
-            "param_groups": state_dict["param_groups"],
-        }
-        fast_state_dict = {
-            "state": state_dict["fast_state"],
-            "param_groups": state_dict["param_groups"],
-        }
-        super(Lookahead, self).load_state_dict(slow_state_dict)
-        self.optimizer.load_state_dict(fast_state_dict)
-        self.fast_state = self.optimizer.state
-
-    def add_param_group(self, param_group):
-        param_group["counter"] = 0
-        self.optimizer.add_param_group(param_group)
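
The three kept lines above are the tail of load_classifier, which reshapes a pretrained classifier's final layer to n output classes by swapping in a zeroed (n, filters) weight and shrinking out_features. A minimal sketch of that head-reshape idiom on a stand-in nn.Linear; resizing the bias is an assumption, since that line sits outside this hunk:

    import torch
    import torch.nn as nn

    n, filters = 2, 2048             # n classes; filters = in_features of the old head
    head = nn.Linear(filters, 1000)  # stand-in for model.last_linear

    head.weight = torch.nn.Parameter(torch.zeros(n, filters))  # as in the kept lines
    head.bias = torch.nn.Parameter(torch.zeros(n))             # assumed companion line
    head.out_features = n                                      # as in the kept lines

    print(head(torch.randn(1, filters)).shape)  # torch.Size([1, 2])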
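
The deleted class is the Lookahead wrapper from https://github.com/alphadl/lookahead.pytorch, the repository the removed train.py comment pointed at: it keeps a slow copy of every parameter and, after each k inner-optimizer steps, moves the slow weights toward the fast ones by a factor alpha and writes them back. A minimal usage sketch matching the commented-out call removed above, assuming the deleted class is still in scope; the model, data, and loop length are stand-ins:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)  # stand-in model
    inner = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer = Lookahead(inner, k=5, alpha=0.5)  # the call train.py had commented out

    for _ in range(20):
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), y)
        inner.zero_grad()  # zero grads on the inner optimizer (see note below)
        loss.backward()
        optimizer.step()   # inner SGD step every iteration; slow/fast sync every k=5 steps

Note that the wrapper never calls Optimizer.__init__, so helpers inherited from torch.optim.Optimizer (zero_grad, the state_dict plumbing through super()) depend on the PyTorch internals of the era; the sketch therefore zeroes gradients through the inner optimizer directly.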