FLOPS report
This commit is contained in:
parent
ea80ba65af
commit
4ac60018f6
|
@ -87,12 +87,15 @@ def model_info(model, verbose=False):
|
||||||
name = name.replace('module_list.', '')
|
name = name.replace('module_list.', '')
|
||||||
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
||||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
||||||
print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g))
|
|
||||||
|
|
||||||
# FLOPs
|
try: # FLOPS
|
||||||
# from thop import profile
|
from thop import profile
|
||||||
# macs, params = profile(model, inputs=(torch.zeros(1, 3, 640, 640),))
|
macs, _ = profile(model, inputs=(torch.zeros(1, 3, 640, 640),))
|
||||||
# print('%.3f GFLOPs' % (macs / 1E9 * 2))
|
fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
|
||||||
|
except:
|
||||||
|
fs = ''
|
||||||
|
|
||||||
|
print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
|
||||||
|
|
||||||
|
|
||||||
def load_classifier(name='resnet101', n=2):
|
def load_classifier(name='resnet101', n=2):
|
||||||
|
|
Loading…
Reference in New Issue