From 4ac60018f6e6c1e24b496485f126a660d9c793d8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 1 Apr 2020 14:05:41 -0700 Subject: [PATCH] FLOPS report --- utils/torch_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index e4490a70..94ca9ded 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -87,12 +87,15 @@ def model_info(model, verbose=False): name = name.replace('module_list.', '') print('%5g %40s %9s %12g %20s %10.3g %10.3g' % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) - print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g)) - # FLOPs - # from thop import profile - # macs, params = profile(model, inputs=(torch.zeros(1, 3, 640, 640),)) - # print('%.3f GFLOPs' % (macs / 1E9 * 2)) + try: # FLOPS + from thop import profile + macs, _ = profile(model, inputs=(torch.zeros(1, 3, 640, 640),)) + fs = ', %.1f GFLOPS' % (macs / 1E9 * 2) + except: + fs = '' + + print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs)) def load_classifier(name='resnet101', n=2):