From 4af819449c9cfd2ad6822c4111422d8dc73257e5 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 29 Jul 2019 12:06:29 +0200 Subject: [PATCH] updates --- models.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/models.py b/models.py index 949e1d53..16c2c243 100755 --- a/models.py +++ b/models.py @@ -1,6 +1,8 @@ from utils.parse_config import * from utils.utils import * +import torch.nn.functional as F + ONNX_EXPORT = False @@ -49,11 +51,19 @@ def create_modules(module_defs): layers = [int(x) for x in module_def['layers'].split(',')] filters = sum([output_filters[i + 1 if i > 0 else i] for i in layers]) modules.add_module('route_%d' % i, EmptyLayer()) + # if module_defs[i+1]['type'] == 'reorg3d': + # upsample = nn.Upsample(scale_factor=1/float(module_defs[i+1]['stride']), mode='nearest') + # modules.add_module('reorg3d_%d' % i, upsample) elif module_def['type'] == 'shortcut': filters = output_filters[int(module_def['from'])] modules.add_module('shortcut_%d' % i, EmptyLayer()) + elif module_def['type'] == 'reorg3d': + # torch.Size([16, 128, 104, 104]) + # torch.Size([16, 64, 208, 208]) <-- # stride 2 interpolate dimensions 2 and 3 to cat with prior layer + pass + elif module_def['type'] == 'yolo': yolo_index += 1 anchor_idxs = [int(x) for x in module_def['mask'].split(',')] @@ -186,7 +196,12 @@ class Darknet(nn.Module): if len(layer_i) == 1: x = layer_outputs[layer_i[0]] else: - x = torch.cat([layer_outputs[i] for i in layer_i], 1) + try: + x = torch.cat([layer_outputs[i] for i in layer_i], 1) + except: # apply stride 2 for darknet reorg layer + layer_outputs[layer_i[1]] = F.interpolate(layer_outputs[layer_i[1]], scale_factor=[0.5, 0.5]) + x = torch.cat([layer_outputs[i] for i in layer_i], 1) + # print(''), [print(layer_outputs[i].shape) for i in layer_i], print(x.shape) elif mtype == 'shortcut': layer_i = int(module_def['from']) x = layer_outputs[-1] + layer_outputs[layer_i]