weight_decay fix
This commit is contained in:
		
							parent
							
								
									ff82e4d488
								
							
						
					
					
						commit
						798a7396f1
					
				
							
								
								
									
										5
									
								
								train.py
								
								
								
								
							
							
						
						
									
										5
									
								
								train.py
								
								
								
								
							|  | @ -261,9 +261,8 @@ def train(): | ||||||
|                 print('WARNING: nan loss detected, ending training') |                 print('WARNING: nan loss detected, ending training') | ||||||
|                 return results |                 return results | ||||||
| 
 | 
 | ||||||
|             # Divide by accumulation count |             # Scale loss by nominal batch_size of 64 | ||||||
|             if accumulate > 1: |             loss *= batch_size / 64 | ||||||
|                 loss /= accumulate |  | ||||||
| 
 | 
 | ||||||
|             # Compute gradient |             # Compute gradient | ||||||
|             if mixed_precision: |             if mixed_precision: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue