updates
This commit is contained in:
parent
ee319aeefd
commit
f67e1afe3e
|
@ -77,3 +77,20 @@ def model_info(model, report='summary'):
|
||||||
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
||||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
||||||
print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g))
|
print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g))
|
||||||
|
|
||||||
|
|
||||||
|
def load_classifier(name='resnet101', n=2):
|
||||||
|
# Loads a pretrained model reshaped to n-class output
|
||||||
|
import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
|
||||||
|
model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
|
||||||
|
|
||||||
|
# Display model properties
|
||||||
|
for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']:
|
||||||
|
print(x + ' =', eval(x))
|
||||||
|
|
||||||
|
# Reshape output to n classes
|
||||||
|
filters = model.last_linear.weight.shape[1]
|
||||||
|
model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
|
||||||
|
model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
|
||||||
|
model.last_linear.out_features = n
|
||||||
|
return model
|
||||||
|
|
Loading…
Reference in New Issue