updates
This commit is contained in:
		
							parent
							
								
									6ca8277de2
								
							
						
					
					
						commit
						41bf46a419
					
				
							
								
								
									
										2
									
								
								train.py
								
								
								
								
							
							
						
						
									
										2
									
								
								train.py
								
								
								
								
							|  | @ -240,7 +240,7 @@ def train(): | ||||||
|             # Hyperparameter Burn-in |             # Hyperparameter Burn-in | ||||||
|             n_burn = 200  # number of burn-in batches |             n_burn = 200  # number of burn-in batches | ||||||
|             if ni <= n_burn: |             if ni <= n_burn: | ||||||
|                 g = (ni / n_burn) ** 2  # gain |                 g = ni / n_burn  # gain | ||||||
|                 for x in model.named_modules(): |                 for x in model.named_modules(): | ||||||
|                     if x[0].endswith('BatchNorm2d'): |                     if x[0].endswith('BatchNorm2d'): | ||||||
|                         # x[1].momentum = 1 - 0.9 * g  # momentum falls from 1 - 0.1 |                         # x[1].momentum = 1 - 0.9 * g  # momentum falls from 1 - 0.1 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue