diff --git a/models.py b/models.py index 6401ab6c..17ff234f 100755 --- a/models.py +++ b/models.py @@ -51,6 +51,7 @@ def create_modules(module_defs, img_size, arc): elif mdef['type'] == 'upsample': modules = nn.Upsample(scale_factor=int(mdef['stride']), mode='nearest') + # modules = Upsample(scale_factor=int(mdef['stride'])) elif mdef['type'] == 'route': # nn.Sequential() placeholder for 'route' layer layers = [int(x) for x in mdef['layers'].split(',')] @@ -141,6 +142,16 @@ class Mish(nn.Module): # https://github.com/digantamisra98/Mish return x.mul_(F.softplus(x).tanh()) +class Upsample(nn.Module): + def __init__(self, scale_factor): + super(Upsample, self).__init__() + self.scale = scale_factor + + def forward(self, x): + h, w = x.shape[2:] + return F.interpolate(x, size=(int(h * self.scale), int(w * self.scale))) + + class YOLOLayer(nn.Module): def __init__(self, anchors, nc, img_size, yolo_index, arc): super(YOLOLayer, self).__init__()