updates
This commit is contained in:
parent
a9cbc28214
commit
00862e47ef
30
models.py
30
models.py
|
@ -254,7 +254,7 @@ class Darknet(nn.Module):
|
|||
|
||||
def forward(self, x, var=None):
|
||||
img_size = x.shape[-2:]
|
||||
output, layer_outputs = [], []
|
||||
yolo_out, out = [], []
|
||||
verbose = False
|
||||
if verbose:
|
||||
print('0', x.shape)
|
||||
|
@ -265,34 +265,34 @@ class Darknet(nn.Module):
|
|||
x = module(x)
|
||||
elif mtype == 'shortcut': # sum
|
||||
if verbose:
|
||||
print('shortcut/add %s + %s' % (x.shape, [layer_outputs[i].shape for i in module.layers]))
|
||||
x = module(x, layer_outputs) # weightedFeatureFusion()
|
||||
print('shortcut/add %s + %s' % (list(x.shape), [list(out[i].shape) for i in module.layers]))
|
||||
x = module(x, out) # weightedFeatureFusion()
|
||||
elif mtype == 'route': # concat
|
||||
layers = [int(x) for x in mdef['layers'].split(',')]
|
||||
if verbose:
|
||||
print('route/concatenate %s' % ([layer_outputs[i].shape for i in layers]))
|
||||
print('route/concatenate %s + %s' % (list(x.shape), [list(out[i].shape) for i in layers]))
|
||||
if len(layers) == 1:
|
||||
x = layer_outputs[layers[0]]
|
||||
x = out[layers[0]]
|
||||
else:
|
||||
try:
|
||||
x = torch.cat([layer_outputs[i] for i in layers], 1)
|
||||
x = torch.cat([out[i] for i in layers], 1)
|
||||
except: # apply stride 2 for darknet reorg layer
|
||||
layer_outputs[layers[1]] = F.interpolate(layer_outputs[layers[1]], scale_factor=[0.5, 0.5])
|
||||
x = torch.cat([layer_outputs[i] for i in layers], 1)
|
||||
# print(''), [print(layer_outputs[i].shape) for i in layers], print(x.shape)
|
||||
out[layers[1]] = F.interpolate(out[layers[1]], scale_factor=[0.5, 0.5])
|
||||
x = torch.cat([out[i] for i in layers], 1)
|
||||
# print(''), [print(out[i].shape) for i in layers], print(x.shape)
|
||||
elif mtype == 'yolo':
|
||||
output.append(module(x, img_size))
|
||||
layer_outputs.append(x if i in self.routs else [])
|
||||
yolo_out.append(module(x, img_size))
|
||||
out.append(x if i in self.routs else [])
|
||||
if verbose:
|
||||
print(i, x.shape)
|
||||
print(i, list(x.shape))
|
||||
|
||||
if self.training: # train
|
||||
return output
|
||||
return yolo_out
|
||||
elif ONNX_EXPORT: # export
|
||||
x = [torch.cat(x, 0) for x in zip(*output)]
|
||||
x = [torch.cat(x, 0) for x in zip(*yolo_out)]
|
||||
return x[0], torch.cat(x[1:3], 1) # scores, boxes: 3780x80, 3780x4
|
||||
else: # test
|
||||
io, p = zip(*output) # inference output, training output
|
||||
io, p = zip(*yolo_out) # inference output, training output
|
||||
return torch.cat(io, 1), p
|
||||
|
||||
def fuse(self):
|
||||
|
|
Loading…
Reference in New Issue