model.info() method implemented
This commit is contained in:
parent
b89cc396af
commit
9ce4ec48a7
|
@ -36,7 +36,6 @@ def detect(save_img=False):
|
|||
|
||||
# Fuse Conv2d + BatchNorm2d layers
|
||||
# model.fuse()
|
||||
# torch_utils.model_info(model, report='summary') # 'full' or 'summary'
|
||||
|
||||
# Eval mode
|
||||
model.to(device).eval()
|
||||
|
|
|
@ -240,6 +240,7 @@ class Darknet(nn.Module):
|
|||
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
||||
self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision
|
||||
self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training
|
||||
self.info() # print model description
|
||||
|
||||
def forward(self, x, verbose=False):
|
||||
img_size = x.shape[-2:]
|
||||
|
@ -291,6 +292,7 @@ class Darknet(nn.Module):
|
|||
|
||||
def fuse(self):
|
||||
# Fuse Conv2d + BatchNorm2d layers throughout model
|
||||
print('Fusing Conv2d() and BatchNorm2d() layers...')
|
||||
fused_list = nn.ModuleList()
|
||||
for a in list(self.children())[0]:
|
||||
if isinstance(a, nn.Sequential):
|
||||
|
@ -303,7 +305,10 @@ class Darknet(nn.Module):
|
|||
break
|
||||
fused_list.append(a)
|
||||
self.module_list = fused_list
|
||||
# model_info(self) # yolov3-spp reduced from 225 to 152 layers
|
||||
self.info() # yolov3-spp reduced from 225 to 152 layers
|
||||
|
||||
def info(self, verbose=False):
|
||||
torch_utils.model_info(self, verbose)
|
||||
|
||||
|
||||
def get_yolo_layers(model):
|
||||
|
|
1
train.py
1
train.py
|
@ -207,7 +207,6 @@ def train():
|
|||
# torch.autograd.set_detect_anomaly(True)
|
||||
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
||||
t0 = time.time()
|
||||
torch_utils.model_info(model, report='summary') # 'full' or 'summary'
|
||||
print('Using %g dataloader workers' % nw)
|
||||
print('Starting training for %g epochs...' % epochs)
|
||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||
|
|
|
@ -75,11 +75,11 @@ def fuse_conv_and_bn(conv, bn):
|
|||
return fusedconv
|
||||
|
||||
|
||||
def model_info(model, report='summary'):
|
||||
def model_info(model, verbose=False):
|
||||
# Plots a line-by-line description of a PyTorch model
|
||||
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
||||
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
|
||||
if report == 'full':
|
||||
if verbose:
|
||||
print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
name = name.replace('module_list.', '')
|
||||
|
|
Loading…
Reference in New Issue