memory-saving routs update

Glenn Jocher 2019-08-12 13:37:11 +02:00
parent daaa8194a9
commit 4ac6e88ea9
1 changed file with 9 additions and 6 deletions


@@ -14,9 +14,10 @@ def create_modules(module_defs):
     hyperparams = module_defs.pop(0)
     output_filters = [int(hyperparams['channels'])]
     module_list = nn.ModuleList()
+    routs = []  # list of layers which route to deeper layers
     yolo_index = -1
 
-    for mdef in module_defs:
+    for i, mdef in enumerate(module_defs):
         modules = nn.Sequential()
 
         if mdef['type'] == 'convolutional':
@@ -53,11 +54,14 @@
         elif mdef['type'] == 'route':  # nn.Sequential() placeholder for 'route' layer
             layers = [int(x) for x in mdef['layers'].split(',')]
             filters = sum([output_filters[i + 1 if i > 0 else i] for i in layers])
+            routs.extend([l if l > 0 else l + i for l in layers])
             # if mdef[i+1]['type'] == 'reorg3d':
             #     modules = nn.Upsample(scale_factor=1/float(mdef[i+1]['stride']), mode='nearest')  # reorg3d
 
         elif mdef['type'] == 'shortcut':  # nn.Sequential() placeholder for 'shortcut' layer
             filters = output_filters[int(mdef['from'])]
+            layer = int(mdef['from'])
+            routs.extend([i + layer if layer < 0 else layer])
 
         elif mdef['type'] == 'reorg3d':  # yolov3-spp-pan-scale
             # torch.Size([16, 128, 104, 104])
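
For reference, the new routs bookkeeping above converts the relative layer references used by 'route' and 'shortcut' blocks into absolute layer indices. A minimal sketch of that arithmetic, using hypothetical layer positions and cfg values rather than any real yolov3 .cfg:

def routed_layers(i, mdef):
    # Absolute indices of earlier layers whose outputs layer i consumes
    if mdef['type'] == 'route':
        layers = [int(x) for x in mdef['layers'].split(',')]
        return [l if l > 0 else l + i for l in layers]
    elif mdef['type'] == 'shortcut':
        layer = int(mdef['from'])
        return [i + layer if layer < 0 else layer]
    return []

routs = []
routs.extend(routed_layers(86, {'type': 'route', 'layers': '-1, 61'}))  # [85, 61]
routs.extend(routed_layers(99, {'type': 'shortcut', 'from': '-3'}))     # [96]
print(sorted(routs))  # [61, 85, 96]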
@@ -80,7 +84,7 @@
         module_list.append(modules)
         output_filters.append(filters)
 
-    return hyperparams, module_list
+    return module_list, routs
 
 
 class Swish(nn.Module):
@@ -171,9 +175,8 @@ class Darknet(nn.Module):
         super(Darknet, self).__init__()
 
         self.module_defs = parse_model_cfg(cfg)
-        self.module_defs[0]['cfg'] = cfg
         self.module_defs[0]['height'] = img_size
-        self.hyperparams, self.module_list = create_modules(self.module_defs)
+        self.module_list, self.routs = create_modules(self.module_defs)
         self.yolo_layers = get_yolo_layers(self)
 
         # Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
@@ -201,11 +204,11 @@ class Darknet(nn.Module):
                     x = torch.cat([layer_outputs[i] for i in layer_i], 1)
                     # print(''), [print(layer_outputs[i].shape) for i in layer_i], print(x.shape)
             elif mtype == 'shortcut':
-                x = layer_outputs[-1] + layer_outputs[int(mdef['from'])]
+                x = x + layer_outputs[int(mdef['from'])]
             elif mtype == 'yolo':
                 x = module(x, img_size)
                 output.append(x)
-            layer_outputs.append(x)
+            layer_outputs.append(x if i in self.routs else [])
 
         if self.training:
             return output
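
Net effect: the forward pass now keeps only the feature maps that a later 'route' or 'shortcut' layer will read back, and stores an empty placeholder for everything else, so those intermediate tensors can be freed right away. A rough sketch of the pattern, with a made-up three-layer module list and routs standing in for the real Darknet model:

import torch
import torch.nn as nn

modules = nn.ModuleList([nn.Conv2d(3, 16, 3, padding=1),
                         nn.Conv2d(16, 16, 3, padding=1),
                         nn.Sequential()])  # placeholder for a 'shortcut' layer
routs = [0]  # only layer 0's output is read again (by the shortcut at layer 2)

x = torch.zeros(1, 3, 64, 64)
layer_outputs = []
for i, module in enumerate(modules):
    if i == 2:                 # 'shortcut' from layer 0: residual add
        x = x + layer_outputs[0]
    else:
        x = module(x)
    # Memory-saving gate from this commit: keep the output only if a deeper
    # layer routes back to it, otherwise append an empty placeholder
    layer_outputs.append(x if i in routs else [])

print([tuple(o.shape) if torch.is_tensor(o) else o for o in layer_outputs])
# [(1, 16, 64, 64), [], []]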