model.info() method implemented

This commit is contained in:
Glenn Jocher 2020-03-14 16:46:54 -07:00
parent b89cc396af
commit 9ce4ec48a7
4 changed files with 8 additions and 5 deletions

View File

@ -36,7 +36,6 @@ def detect(save_img=False):
# Fuse Conv2d + BatchNorm2d layers
# model.fuse()
# torch_utils.model_info(model, report='summary') # 'full' or 'summary'
# Eval mode
model.to(device).eval()

View File

@ -240,6 +240,7 @@ class Darknet(nn.Module):
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision
self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training
self.info() # print model description
def forward(self, x, verbose=False):
img_size = x.shape[-2:]
@ -291,6 +292,7 @@ class Darknet(nn.Module):
def fuse(self):
# Fuse Conv2d + BatchNorm2d layers throughout model
print('Fusing Conv2d() and BatchNorm2d() layers...')
fused_list = nn.ModuleList()
for a in list(self.children())[0]:
if isinstance(a, nn.Sequential):
@ -303,7 +305,10 @@ class Darknet(nn.Module):
break
fused_list.append(a)
self.module_list = fused_list
# model_info(self) # yolov3-spp reduced from 225 to 152 layers
self.info() # yolov3-spp reduced from 225 to 152 layers
def info(self, verbose=False):
torch_utils.model_info(self, verbose)
def get_yolo_layers(model):

View File

@ -207,7 +207,6 @@ def train():
# torch.autograd.set_detect_anomaly(True)
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
t0 = time.time()
torch_utils.model_info(model, report='summary') # 'full' or 'summary'
print('Using %g dataloader workers' % nw)
print('Starting training for %g epochs...' % epochs)
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------

View File

@ -75,11 +75,11 @@ def fuse_conv_and_bn(conv, bn):
return fusedconv
def model_info(model, report='summary'):
def model_info(model, verbose=False):
# Plots a line-by-line description of a PyTorch model
n_p = sum(x.numel() for x in model.parameters()) # number parameters
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
if report == 'full':
if verbose:
print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace('module_list.', '')