EMA implemented by default
This commit is contained in:
		
							parent
							
								
									dc8e56b9f3
								
							
						
					
					
						commit
						9c5e76b93d
					
				
							
								
								
									
										12
									
								
								train.py
								
								
								
								
							
							
						
						
									
										12
									
								
								train.py
								
								
								
								
							|  | @ -197,7 +197,7 @@ def train(): | |||
|     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights | ||||
| 
 | ||||
|     # Model EMA | ||||
|     # ema = torch_utils.ModelEMA(model, decay=0.9998) | ||||
|     ema = torch_utils.ModelEMA(model) | ||||
| 
 | ||||
|     # Start training | ||||
|     nb = len(dataloader)  # number of batches | ||||
|  | @ -291,7 +291,7 @@ def train(): | |||
|             if ni % accumulate == 0: | ||||
|                 optimizer.step() | ||||
|                 optimizer.zero_grad() | ||||
|                 # ema.update(model) | ||||
|                 ema.update(model) | ||||
| 
 | ||||
|             # Print batch results | ||||
|             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses | ||||
|  | @ -305,7 +305,7 @@ def train(): | |||
|         scheduler.step() | ||||
| 
 | ||||
|         # Process epoch results | ||||
|         # ema.update_attr(model) | ||||
|         ema.update_attr(model) | ||||
|         final_epoch = epoch + 1 == epochs | ||||
|         if not opt.notest or final_epoch:  # Calculate mAP | ||||
|             is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80 | ||||
|  | @ -313,7 +313,7 @@ def train(): | |||
|                                       data, | ||||
|                                       batch_size=batch_size * 2, | ||||
|                                       img_size=img_size_test, | ||||
|                                       model=model, | ||||
|                                       model=ema.ema, | ||||
|                                       conf_thres=0.001 if final_epoch else 0.01,  # 0.001 for best mAP, 0.01 for speed | ||||
|                                       iou_thres=0.6, | ||||
|                                       save_json=final_epoch and is_coco, | ||||
|  | @ -347,7 +347,7 @@ def train(): | |||
|                 chkpt = {'epoch': epoch, | ||||
|                          'best_fitness': best_fitness, | ||||
|                          'training_results': f.read(), | ||||
|                          'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(), | ||||
|                          'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(), | ||||
|                          'optimizer': None if final_epoch else optimizer.state_dict()} | ||||
| 
 | ||||
|             # Save last checkpoint | ||||
|  | @ -377,7 +377,7 @@ def train(): | |||
|         if opt.bucket:  # save to cloud | ||||
|             os.system('gsutil cp %s gs://%s/results' % (fresults, opt.bucket)) | ||||
|             os.system('gsutil cp %s gs://%s/weights' % (wdir + flast, opt.bucket)) | ||||
|             # os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket)) | ||||
|             os.system('gsutil cp %s gs://%s/weights' % (wdir + fbest, opt.bucket)) | ||||
| 
 | ||||
|     if not opt.evolve: | ||||
|         plot_results()  # save as results.png | ||||
|  |  | |||
|  | @ -1,3 +1,4 @@ | |||
| import math | ||||
| import os | ||||
| import time | ||||
| from copy import deepcopy | ||||
|  | @ -139,11 +140,12 @@ class ModelEMA: | |||
|     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model, decay=0.9998, device=''): | ||||
|     def __init__(self, model, decay=0.9999, device=''): | ||||
|         # make a copy of the model for accumulating moving average of weights | ||||
|         self.ema = deepcopy(model) | ||||
|         self.ema.eval() | ||||
|         self.decay = decay | ||||
|         self.updates = 0  # number of EMA updates | ||||
|         self.decay = lambda x: decay * (1 - math.exp(-x / 1000))  # decay exponential ramp (to help early epochs) | ||||
|         self.device = device  # perform ema on different device from model if set | ||||
|         if device: | ||||
|             self.ema.to(device=device) | ||||
|  | @ -151,7 +153,8 @@ class ModelEMA: | |||
|             p.requires_grad_(False) | ||||
| 
 | ||||
|     def update(self, model): | ||||
|         d = self.decay | ||||
|         self.updates += 1 | ||||
|         d = self.decay(self.updates) | ||||
|         with torch.no_grad(): | ||||
|             if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel): | ||||
|                 msd, esd = model.module.state_dict(), self.ema.module.state_dict() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue