From a8f443518f8895824c4f528195f568ed2d2bdd8e Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 5 May 2019 13:44:12 +0200 Subject: [PATCH] updates --- utils/utils.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 156da436..6a33e5f2 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -37,16 +37,17 @@ def load_classes(path): return list(filter(None, names)) # filter removes empty strings (such as last line) -def model_info(model): +def model_info(model, report='summary'): # Plots a line-by-line description of a PyTorch model n_p = sum(x.numel() for x in model.parameters()) # number parameters n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients - print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma')) - for i, (name, p) in enumerate(model.named_parameters()): - name = name.replace('module_list.', '') - print('%5g %40s %9s %12g %20s %10.3g %10.3g' % ( - i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) - print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g)) + if report is 'full': + print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma')) + for i, (name, p) in enumerate(model.named_parameters()): + name = name.replace('module_list.', '') + print('%5g %40s %9s %12g %20s %10.3g %10.3g' % + (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) + print('Model Summary: %g layers, %g parameters, %g gradients' % (len(list(model.parameters())), n_p, n_g)) def labels_to_class_weights(labels, nc=80): @@ -61,12 +62,12 @@ def labels_to_class_weights(labels, nc=80): def coco_class_weights(): # frequency of each class in coco train2014 - weights = 1 / torch.FloatTensor( - [187437, 4955, 30920, 6033, 3838, 4332, 3160, 7051, 7677, 9167, 1316, 1372, 833, 6757, 7355, 3302, 3776, 4671, + n = [187437, 4955, 30920, 6033, 3838, 4332, 3160, 7051, 7677, 9167, 1316, 1372, 833, 6757, 7355, 3302, 3776, 4671, 6769, 5706, 3908, 903, 3686, 3596, 6200, 7920, 8779, 4505, 4272, 1862, 4698, 1962, 4403, 6659, 2402, 2689, 4012, 4175, 3411, 17048, 5637, 14553, 3923, 5539, 4289, 10084, 7018, 4314, 3099, 4638, 4939, 5543, 2038, 4004, 5053, 4578, 27292, 4113, 5931, 2905, 11174, 2873, 4036, 3415, 1517, 4122, 1980, 4464, 1190, 2302, 156, 3933, - 1877, 17630, 4337, 4624, 1075, 3468, 135, 1380]) + 1877, 17630, 4337, 4624, 1075, 3468, 135, 1380] + weights = 1 / torch.Tensor(n) weights /= weights.sum() return weights