updates
This commit is contained in:
parent
981abf679c
commit
4af819449c
15
models.py
15
models.py
|
@ -1,6 +1,8 @@
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,11 +51,19 @@ def create_modules(module_defs):
|
||||||
layers = [int(x) for x in module_def['layers'].split(',')]
|
layers = [int(x) for x in module_def['layers'].split(',')]
|
||||||
filters = sum([output_filters[i + 1 if i > 0 else i] for i in layers])
|
filters = sum([output_filters[i + 1 if i > 0 else i] for i in layers])
|
||||||
modules.add_module('route_%d' % i, EmptyLayer())
|
modules.add_module('route_%d' % i, EmptyLayer())
|
||||||
|
# if module_defs[i+1]['type'] == 'reorg3d':
|
||||||
|
# upsample = nn.Upsample(scale_factor=1/float(module_defs[i+1]['stride']), mode='nearest')
|
||||||
|
# modules.add_module('reorg3d_%d' % i, upsample)
|
||||||
|
|
||||||
elif module_def['type'] == 'shortcut':
|
elif module_def['type'] == 'shortcut':
|
||||||
filters = output_filters[int(module_def['from'])]
|
filters = output_filters[int(module_def['from'])]
|
||||||
modules.add_module('shortcut_%d' % i, EmptyLayer())
|
modules.add_module('shortcut_%d' % i, EmptyLayer())
|
||||||
|
|
||||||
|
elif module_def['type'] == 'reorg3d':
|
||||||
|
# torch.Size([16, 128, 104, 104])
|
||||||
|
# torch.Size([16, 64, 208, 208]) <-- # stride 2 interpolate dimensions 2 and 3 to cat with prior layer
|
||||||
|
pass
|
||||||
|
|
||||||
elif module_def['type'] == 'yolo':
|
elif module_def['type'] == 'yolo':
|
||||||
yolo_index += 1
|
yolo_index += 1
|
||||||
anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
|
anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
|
||||||
|
@ -186,7 +196,12 @@ class Darknet(nn.Module):
|
||||||
if len(layer_i) == 1:
|
if len(layer_i) == 1:
|
||||||
x = layer_outputs[layer_i[0]]
|
x = layer_outputs[layer_i[0]]
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
x = torch.cat([layer_outputs[i] for i in layer_i], 1)
|
x = torch.cat([layer_outputs[i] for i in layer_i], 1)
|
||||||
|
except: # apply stride 2 for darknet reorg layer
|
||||||
|
layer_outputs[layer_i[1]] = F.interpolate(layer_outputs[layer_i[1]], scale_factor=[0.5, 0.5])
|
||||||
|
x = torch.cat([layer_outputs[i] for i in layer_i], 1)
|
||||||
|
# print(''), [print(layer_outputs[i].shape) for i in layer_i], print(x.shape)
|
||||||
elif mtype == 'shortcut':
|
elif mtype == 'shortcut':
|
||||||
layer_i = int(module_def['from'])
|
layer_i = int(module_def['from'])
|
||||||
x = layer_outputs[-1] + layer_outputs[layer_i]
|
x = layer_outputs[-1] + layer_outputs[layer_i]
|
||||||
|
|
Loading…
Reference in New Issue