This commit is contained in:
Glenn Jocher 2019-12-08 19:58:10 -08:00
parent 1bf717ef9c
commit 37fa9afaff
1 changed files with 2 additions and 1 deletions

View File

@ -101,6 +101,7 @@ def load_classifier(name='resnet101', n=2):
from collections import defaultdict from collections import defaultdict
from torch.optim import Optimizer from torch.optim import Optimizer
class Lookahead(Optimizer): class Lookahead(Optimizer):
def __init__(self, optimizer, k=5, alpha=0.5): def __init__(self, optimizer, k=5, alpha=0.5):
self.optimizer = optimizer self.optimizer = optimizer
@ -165,4 +166,4 @@ class Lookahead(Optimizer):
def add_param_group(self, param_group): def add_param_group(self, param_group):
param_group["counter"] = 0 param_group["counter"] = 0
self.optimizer.add_param_group(param_group) self.optimizer.add_param_group(param_group)