diff --git a/utils/torch_utils.py b/utils/torch_utils.py index ac38249c..18664b3a 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -88,6 +88,11 @@ def model_info(model, verbose=False): (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g)) + # Report FLOPS + # from thop import profile + # macs, params = profile(model, inputs=(torch.zeros(1, 3, 608, 608),)) + # print('%.3f FLOPS' % (macs / 1E9 * 2)) + def load_classifier(name='resnet101', n=2): # Loads a pretrained model reshaped to n-class output