This commit is contained in:
Glenn Jocher 2020-02-19 15:16:00 -08:00
parent a9cbc28214
commit 00862e47ef
1 changed files with 15 additions and 15 deletions

View File

@ -254,7 +254,7 @@ class Darknet(nn.Module):
def forward(self, x, var=None):
img_size = x.shape[-2:]
output, layer_outputs = [], []
yolo_out, out = [], []
verbose = False
if verbose:
print('0', x.shape)
@ -265,34 +265,34 @@ class Darknet(nn.Module):
x = module(x)
elif mtype == 'shortcut': # sum
if verbose:
print('shortcut/add %s + %s' % (x.shape, [layer_outputs[i].shape for i in module.layers]))
x = module(x, layer_outputs) # weightedFeatureFusion()
print('shortcut/add %s + %s' % (list(x.shape), [list(out[i].shape) for i in module.layers]))
x = module(x, out) # weightedFeatureFusion()
elif mtype == 'route': # concat
layers = [int(x) for x in mdef['layers'].split(',')]
if verbose:
print('route/concatenate %s' % ([layer_outputs[i].shape for i in layers]))
print('route/concatenate %s + %s' % (list(x.shape), [list(out[i].shape) for i in layers]))
if len(layers) == 1:
x = layer_outputs[layers[0]]
x = out[layers[0]]
else:
try:
x = torch.cat([layer_outputs[i] for i in layers], 1)
x = torch.cat([out[i] for i in layers], 1)
except: # apply stride 2 for darknet reorg layer
layer_outputs[layers[1]] = F.interpolate(layer_outputs[layers[1]], scale_factor=[0.5, 0.5])
x = torch.cat([layer_outputs[i] for i in layers], 1)
# print(''), [print(layer_outputs[i].shape) for i in layers], print(x.shape)
out[layers[1]] = F.interpolate(out[layers[1]], scale_factor=[0.5, 0.5])
x = torch.cat([out[i] for i in layers], 1)
# print(''), [print(out[i].shape) for i in layers], print(x.shape)
elif mtype == 'yolo':
output.append(module(x, img_size))
layer_outputs.append(x if i in self.routs else [])
yolo_out.append(module(x, img_size))
out.append(x if i in self.routs else [])
if verbose:
print(i, x.shape)
print(i, list(x.shape))
if self.training: # train
return output
return yolo_out
elif ONNX_EXPORT: # export
x = [torch.cat(x, 0) for x in zip(*output)]
x = [torch.cat(x, 0) for x in zip(*yolo_out)]
return x[0], torch.cat(x[1:3], 1) # scores, boxes: 3780x80, 3780x4
else: # test
io, p = zip(*output) # inference output, training output
io, p = zip(*yolo_out) # inference output, training output
return torch.cat(io, 1), p
def fuse(self):