updates
This commit is contained in:
parent
1bf717ef9c
commit
37fa9afaff
|
@ -101,6 +101,7 @@ def load_classifier(name='resnet101', n=2):
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
class Lookahead(Optimizer):
|
class Lookahead(Optimizer):
|
||||||
def __init__(self, optimizer, k=5, alpha=0.5):
|
def __init__(self, optimizer, k=5, alpha=0.5):
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
|
@ -165,4 +166,4 @@ class Lookahead(Optimizer):
|
||||||
|
|
||||||
def add_param_group(self, param_group):
|
def add_param_group(self, param_group):
|
||||||
param_group["counter"] = 0
|
param_group["counter"] = 0
|
||||||
self.optimizer.add_param_group(param_group)
|
self.optimizer.add_param_group(param_group)
|
||||||
|
|
Loading…
Reference in New Issue