From f67e1afe3e6cc39cd781b65991a23b1be55090b2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 10 Oct 2019 14:40:18 +0200 Subject: [PATCH] updates --- utils/torch_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 0c98efcd..b631262e 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -77,3 +77,20 @@ def model_info(model, report='summary'): print('%5g %40s %9s %12g %20s %10.3g %10.3g' % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g)) + + +def load_classifier(name='resnet101', n=2): + # Loads a pretrained model reshaped to n-class output + import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision + model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet') + + # Display model properties + for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']: + print(x + ' =', eval(x)) + + # Reshape output to n classes + filters = model.last_linear.weight.shape[1] + model.last_linear.bias = torch.nn.Parameter(torch.zeros(n)) + model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters)) + model.last_linear.out_features = n + return model