diff --git a/models.py b/models.py index ed7f2ded..ecbaa341 100755 --- a/models.py +++ b/models.py @@ -41,7 +41,8 @@ def create_modules(module_defs): modules.add_module('maxpool_%d' % i, maxpool) elif module_def['type'] == 'upsample': - upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') + # upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') # WARNING: deprecated + upsample = Upsample(scale_factor=int(module_def['stride']), mode='nearest') modules.add_module('upsample_%d' % i, upsample) elif module_def['type'] == 'route': @@ -79,6 +80,18 @@ class EmptyLayer(nn.Module): super(EmptyLayer, self).__init__() +class Upsample(torch.nn.Module): + # Custom Upsample layer (nn.Upsample gives deprecated warning message) + + def __init__(self, scale_factor=1, mode='nearest'): + super(Upsample, self).__init__() + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x): + return torch.nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) + + class YOLOLayer(nn.Module): def __init__(self, anchors, nC, img_dim, anchor_idxs, cfg):