EMA class updates
This commit is contained in:
parent
666ba85ed3
commit
b89cc396af
5
train.py
5
train.py
|
@ -217,7 +217,6 @@ def train():
|
||||||
if prebias:
|
if prebias:
|
||||||
ne = 3 # number of prebias epochs
|
ne = 3 # number of prebias epochs
|
||||||
ps = 0.1, 0.9 # prebias settings (lr=0.1, momentum=0.9)
|
ps = 0.1, 0.9 # prebias settings (lr=0.1, momentum=0.9)
|
||||||
model.gr = 0.0 # giou loss ratio (obj_loss = 1.0)
|
|
||||||
if epoch == ne:
|
if epoch == ne:
|
||||||
ps = hyp['lr0'], hyp['momentum'] # normal training settings
|
ps = hyp['lr0'], hyp['momentum'] # normal training settings
|
||||||
model.gr = 1.0 # giou loss ratio (obj_loss = giou)
|
model.gr = 1.0 # giou loss ratio (obj_loss = giou)
|
||||||
|
@ -307,6 +306,7 @@ def train():
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# Process epoch results
|
# Process epoch results
|
||||||
|
# ema.update_attr(model)
|
||||||
final_epoch = epoch + 1 == epochs
|
final_epoch = epoch + 1 == epochs
|
||||||
if not opt.notest or final_epoch: # Calculate mAP
|
if not opt.notest or final_epoch: # Calculate mAP
|
||||||
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
|
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
|
||||||
|
@ -348,8 +348,7 @@ def train():
|
||||||
chkpt = {'epoch': epoch,
|
chkpt = {'epoch': epoch,
|
||||||
'best_fitness': best_fitness,
|
'best_fitness': best_fitness,
|
||||||
'training_results': f.read(),
|
'training_results': f.read(),
|
||||||
'model': model.module.state_dict() if type(
|
'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
|
||||||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
|
||||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||||
|
|
||||||
# Save last checkpoint
|
# Save last checkpoint
|
||||||
|
|
|
@ -141,10 +141,14 @@ class ModelEMA:
|
||||||
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||||
else:
|
else:
|
||||||
msd, esd = model.state_dict(), self.ema.state_dict()
|
msd, esd = model.state_dict(), self.ema.state_dict()
|
||||||
|
# self.ema.load_state_dict({k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
|
||||||
|
for k, v in esd.items():
|
||||||
|
if v.dtype.is_floating_point:
|
||||||
|
v *= d
|
||||||
|
v += (1. - d) * msd[k].detach()
|
||||||
|
|
||||||
# self.ema.load_state_dict(
|
def update_attr(self, model):
|
||||||
# {k: esd[k] * d + (1 - d) * v.detach() for k, v in model.items() if v.dtype.is_floating_point})
|
# Assign attributes (which may change during training)
|
||||||
for k in msd.keys():
|
for k in model.__dict__.keys():
|
||||||
if esd[k].dtype.is_floating_point:
|
if not k.startswith('_'):
|
||||||
esd[k] *= d
|
self.ema.__setattr__(k, model.getattr(k))
|
||||||
esd[k] += (1. - d) * msd[k].detach()
|
|
||||||
|
|
Loading…
Reference in New Issue