updates
This commit is contained in:
parent
38fbc1e383
commit
5403581e38
15
models.py
15
models.py
|
@ -41,7 +41,8 @@ def create_modules(module_defs):
|
|||
modules.add_module('maxpool_%d' % i, maxpool)
|
||||
|
||||
elif module_def['type'] == 'upsample':
|
||||
upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest')
|
||||
# upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') # WARNING: deprecated
|
||||
upsample = Upsample(scale_factor=int(module_def['stride']), mode='nearest')
|
||||
modules.add_module('upsample_%d' % i, upsample)
|
||||
|
||||
elif module_def['type'] == 'route':
|
||||
|
@ -79,6 +80,18 @@ class EmptyLayer(nn.Module):
|
|||
super(EmptyLayer, self).__init__()
|
||||
|
||||
|
||||
class Upsample(torch.nn.Module):
|
||||
# Custom Upsample layer (nn.Upsample gives deprecated warning message)
|
||||
|
||||
def __init__(self, scale_factor=1, mode='nearest'):
|
||||
super(Upsample, self).__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
|
||||
|
||||
|
||||
class YOLOLayer(nn.Module):
|
||||
|
||||
def __init__(self, anchors, nC, img_dim, anchor_idxs, cfg):
|
||||
|
|
Loading…
Reference in New Issue