This commit is contained in:
Glenn Jocher 2018-12-24 13:11:21 +01:00
parent 38fbc1e383
commit 5403581e38
1 changed files with 14 additions and 1 deletions

View File

@ -41,7 +41,8 @@ def create_modules(module_defs):
modules.add_module('maxpool_%d' % i, maxpool)
elif module_def['type'] == 'upsample':
upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest')
# upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') # WARNING: deprecated
upsample = Upsample(scale_factor=int(module_def['stride']), mode='nearest')
modules.add_module('upsample_%d' % i, upsample)
elif module_def['type'] == 'route':
@ -79,6 +80,18 @@ class EmptyLayer(nn.Module):
super(EmptyLayer, self).__init__()
class Upsample(torch.nn.Module):
# Custom Upsample layer (nn.Upsample gives deprecated warning message)
def __init__(self, scale_factor=1, mode='nearest'):
super(Upsample, self).__init__()
self.scale_factor = scale_factor
self.mode = mode
def forward(self, x):
return torch.nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
class YOLOLayer(nn.Module):
def __init__(self, anchors, nC, img_dim, anchor_idxs, cfg):