From 5403581e382a27a7b9ee3a5ec96e1f02dbb5a6ff Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 24 Dec 2018 13:11:21 +0100 Subject: [PATCH] updates --- models.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/models.py b/models.py index ed7f2ded..ecbaa341 100755 --- a/models.py +++ b/models.py @@ -41,7 +41,8 @@ def create_modules(module_defs): modules.add_module('maxpool_%d' % i, maxpool) elif module_def['type'] == 'upsample': - upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') + # upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') # WARNING: deprecated + upsample = Upsample(scale_factor=int(module_def['stride']), mode='nearest') modules.add_module('upsample_%d' % i, upsample) elif module_def['type'] == 'route': @@ -79,6 +80,18 @@ class EmptyLayer(nn.Module): super(EmptyLayer, self).__init__() +class Upsample(torch.nn.Module): + # Custom Upsample layer (nn.Upsample gives deprecated warning message) + + def __init__(self, scale_factor=1, mode='nearest'): + super(Upsample, self).__init__() + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x): + return torch.nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) + + class YOLOLayer(nn.Module): def __init__(self, anchors, nC, img_dim, anchor_idxs, cfg):