diff --git a/models.py b/models.py index a48f35d2..f8ed9449 100755 --- a/models.py +++ b/models.py @@ -67,9 +67,9 @@ def create_modules(module_defs, img_size, arc): # modules = nn.Upsample(scale_factor=1/float(mdef[i+1]['stride']), mode='nearest') # reorg3d elif mdef['type'] == 'shortcut': # nn.Sequential() placeholder for 'shortcut' layer - layer = int(mdef['from']) - filters = output_filters[layer] - routs.extend([i + layer if layer < 0 else layer]) + layers = [int(x) for x in mdef['from'].split(',')] + filters = output_filters[layers[0]] + routs.extend([i + l if l < 0 else l for l in layers]) elif mdef['type'] == 'reorg3d': # yolov3-spp-pan-scale # torch.Size([16, 128, 104, 104]) @@ -239,10 +239,10 @@ class Darknet(nn.Module): mtype = mdef['type'] if mtype in ['convolutional', 'upsample', 'maxpool']: x = module(x) - elif mtype == 'route': + elif mtype == 'route': # concat layers = [int(x) for x in mdef['layers'].split(',')] if verbose: - print('route concatenating %s' % ([layer_outputs[i].shape for i in layers])) + print('route/concatenate %s' % ([layer_outputs[i].shape for i in layers])) if len(layers) == 1: x = layer_outputs[layers[0]] else: @@ -252,11 +252,12 @@ class Darknet(nn.Module): layer_outputs[layers[1]] = F.interpolate(layer_outputs[layers[1]], scale_factor=[0.5, 0.5]) x = torch.cat([layer_outputs[i] for i in layers], 1) # print(''), [print(layer_outputs[i].shape) for i in layers], print(x.shape) - elif mtype == 'shortcut': - j = int(mdef['from']) + elif mtype == 'shortcut': # sum + layers = [int(x) for x in mdef['from'].split(',')] if verbose: - print('shortcut adding layer %g-%s to %g-%s' % (j, layer_outputs[j].shape, i - 1, x.shape)) - x = x + layer_outputs[j] + print('shortcut/add %s' % ([layer_outputs[i].shape for i in layers])) + for j in layers: + x = x + layer_outputs[j] elif mtype == 'yolo': output.append(module(x, img_size)) layer_outputs.append(x if i in self.routs else []) diff --git a/utils/parse_config.py b/utils/parse_config.py index 5d3c20fb..2516388a 100644 --- a/utils/parse_config.py +++ b/utils/parse_config.py @@ -33,7 +33,7 @@ def parse_model_cfg(path): # Check all fields are supported supported = ['type', 'batch_normalize', 'filters', 'size', 'stride', 'pad', 'activation', 'layers', 'groups', 'from', 'mask', 'anchors', 'classes', 'num', 'jitter', 'ignore_thresh', 'truth_thresh', 'random', - 'stride_x', 'stride_y'] + 'stride_x', 'stride_y', 'weights_type', 'weights_normalization'] f = [] # fields for x in mdefs[1:]: