FLOPS report
This commit is contained in:
		
							parent
							
								
									60c8d194cd
								
							
						
					
					
						commit
						20454990ce
					
				|  | @ -88,6 +88,11 @@ def model_info(model, verbose=False): | ||||||
|                   (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) |                   (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) | ||||||
|     print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g)) |     print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g)) | ||||||
| 
 | 
 | ||||||
|  |     # Report FLOPS | ||||||
|  |     # from thop import profile | ||||||
|  |     # macs, params = profile(model, inputs=(torch.zeros(1, 3, 608, 608),)) | ||||||
|  |     # print('%.3f FLOPS' % (macs / 1E9 * 2)) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def load_classifier(name='resnet101', n=2): | def load_classifier(name='resnet101', n=2): | ||||||
|     # Loads a pretrained model reshaped to n-class output |     # Loads a pretrained model reshaped to n-class output | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue