This commit is contained in:
Glenn Jocher 2019-12-03 15:34:20 -08:00
parent fcdbd3ee35
commit cae901c2da
1 changed files with 11 additions and 0 deletions

View File

@ -51,6 +51,7 @@ def create_modules(module_defs, img_size, arc):
elif mdef['type'] == 'upsample':
modules = nn.Upsample(scale_factor=int(mdef['stride']), mode='nearest')
# modules = Upsample(scale_factor=int(mdef['stride']))
elif mdef['type'] == 'route': # nn.Sequential() placeholder for 'route' layer
layers = [int(x) for x in mdef['layers'].split(',')]
@ -141,6 +142,16 @@ class Mish(nn.Module): # https://github.com/digantamisra98/Mish
return x.mul_(F.softplus(x).tanh())
class Upsample(nn.Module):
def __init__(self, scale_factor):
super(Upsample, self).__init__()
self.scale = scale_factor
def forward(self, x):
h, w = x.shape[2:]
return F.interpolate(x, size=(int(h * self.scale), int(w * self.scale)))
class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, yolo_index, arc):
super(YOLOLayer, self).__init__()