FLOPS report
This commit is contained in:
		
							parent
							
								
									ea80ba65af
								
							
						
					
					
						commit
						4ac60018f6
					
				|  | @ -87,12 +87,15 @@ def model_info(model, verbose=False): | |||
|             name = name.replace('module_list.', '') | ||||
|             print('%5g %40s %9s %12g %20s %10.3g %10.3g' % | ||||
|                   (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) | ||||
|     print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g)) | ||||
| 
 | ||||
|     # FLOPs | ||||
|     # from thop import profile | ||||
|     # macs, params = profile(model, inputs=(torch.zeros(1, 3, 640, 640),)) | ||||
|     # print('%.3f GFLOPs' % (macs / 1E9 * 2)) | ||||
|     try:  # FLOPS | ||||
|         from thop import profile | ||||
|         macs, _ = profile(model, inputs=(torch.zeros(1, 3, 640, 640),)) | ||||
|         fs = ', %.1f GFLOPS' % (macs / 1E9 * 2) | ||||
|     except: | ||||
|         fs = '' | ||||
| 
 | ||||
|     print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs)) | ||||
| 
 | ||||
| 
 | ||||
| def load_classifier(name='resnet101', n=2): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue