forward updated if-else

This commit is contained in:
Glenn Jocher 2020-04-05 13:49:13 -07:00
parent e81a152a92
commit d04738a27c
1 changed files with 4 additions and 3 deletions

View File

@ -234,9 +234,7 @@ class Darknet(nn.Module):
for i, (mdef, module) in enumerate(zip(self.module_defs, self.module_list)): for i, (mdef, module) in enumerate(zip(self.module_defs, self.module_list)):
mtype = mdef['type'] mtype = mdef['type']
if mtype in ['convolutional', 'upsample', 'maxpool']: if mtype == 'shortcut': # sum
x = module(x)
elif mtype == 'shortcut': # sum
if verbose: if verbose:
l = [i - 1] + module.layers # layers l = [i - 1] + module.layers # layers
s = [list(x.shape)] + [list(out[i].shape) for i in module.layers] # shapes s = [list(x.shape)] + [list(out[i].shape) for i in module.layers] # shapes
@ -259,6 +257,9 @@ class Darknet(nn.Module):
# print(''), [print(out[i].shape) for i in layers], print(x.shape) # print(''), [print(out[i].shape) for i in layers], print(x.shape)
elif mtype == 'yolo': elif mtype == 'yolo':
yolo_out.append(module(x, img_size, out)) yolo_out.append(module(x, img_size, out))
else: # run module directly, i.e. mtype = 'convolutional', 'upsample', 'maxpool', 'batchnorm2d' etc.
x = module(x)
out.append(x if self.routs[i] else []) out.append(x if self.routs[i] else [])
if verbose: if verbose:
print('%g/%g %s -' % (i, len(self.module_list), mtype), list(x.shape), str) print('%g/%g %s -' % (i, len(self.module_list), mtype), list(x.shape), str)