diff --git a/models.py b/models.py index de4ad892..f84da507 100755 --- a/models.py +++ b/models.py @@ -74,6 +74,7 @@ def create_modules(module_defs, img_size): layers = mdef['layers'] filters = sum([output_filters[l + 1 if l > 0 else l] for l in layers]) routs.extend([i + l if l < 0 else l for l in layers]) + modules = FeatureConcat(layers=layers) elif mdef['type'] == 'shortcut': # nn.Sequential() placeholder for 'shortcut' layer layers = mdef['from'] @@ -234,27 +235,12 @@ class Darknet(nn.Module): for i, (mdef, module) in enumerate(zip(self.module_defs, self.module_list)): mtype = mdef['type'] - if mtype == 'shortcut': # sum + if mtype in ['shortcut', 'route']: # sum, concat if verbose: l = [i - 1] + module.layers # layers s = [list(x.shape)] + [list(out[i].shape) for i in module.layers] # shapes str = ' >> ' + ' + '.join(['layer %g %s' % x for x in zip(l, s)]) - x = module(x, out) # WeightedFeatureFusion() - elif mtype == 'route': # concat - layers = mdef['layers'] - if verbose: - l = [i - 1] + layers # layers - s = [list(x.shape)] + [list(out[i].shape) for i in layers] # shapes - str = ' >> ' + ' + '.join(['layer %g %s' % x for x in zip(l, s)]) - if len(layers) == 1: - x = out[layers[0]] - else: - try: - x = torch.cat([out[i] for i in layers], 1) - except: # apply stride 2 for darknet reorg layer - out[layers[1]] = F.interpolate(out[layers[1]], scale_factor=[0.5, 0.5]) - x = torch.cat([out[i] for i in layers], 1) - # print(''), [print(out[i].shape) for i in layers], print(x.shape) + x = module(x, out) # WeightedFeatureFusion(), FeatureConcat() elif mtype == 'yolo': yolo_out.append(module(x, img_size, out)) else: # run module directly, i.e. mtype = 'convolutional', 'upsample', 'maxpool', 'batchnorm2d' etc. diff --git a/utils/layers.py b/utils/layers.py index 1f19279d..6424fba7 100644 --- a/utils/layers.py +++ b/utils/layers.py @@ -3,6 +3,16 @@ import torch.nn.functional as F from utils.utils import * +class FeatureConcat(nn.Module): + def __init__(self, layers): + super(FeatureConcat, self).__init__() + self.layers = layers # layer indices + self.multiple = len(layers) > 1 # multiple layers flag + + def forward(self, x, outputs): + return torch.cat([outputs[i] for i in self.layers], 1) if self.multiple else outputs[self.layers[0]] + + class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, layers, weight=False): super(WeightedFeatureFusion, self).__init__()