updates
This commit is contained in:
parent
5e747f8da9
commit
af8af1ce68
11
models.py
11
models.py
|
@ -51,7 +51,6 @@ def create_modules(module_defs, img_size, arc):
|
||||||
|
|
||||||
elif mdef['type'] == 'upsample':
|
elif mdef['type'] == 'upsample':
|
||||||
modules = nn.Upsample(scale_factor=int(mdef['stride']), mode='nearest')
|
modules = nn.Upsample(scale_factor=int(mdef['stride']), mode='nearest')
|
||||||
# modules = Upsample(scale_factor=int(mdef['stride']))
|
|
||||||
|
|
||||||
elif mdef['type'] == 'route': # nn.Sequential() placeholder for 'route' layer
|
elif mdef['type'] == 'route': # nn.Sequential() placeholder for 'route' layer
|
||||||
layers = [int(x) for x in mdef['layers'].split(',')]
|
layers = [int(x) for x in mdef['layers'].split(',')]
|
||||||
|
@ -142,16 +141,6 @@ class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
||||||
return x.mul_(F.softplus(x).tanh())
|
return x.mul_(F.softplus(x).tanh())
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, scale_factor):
|
|
||||||
super(Upsample, self).__init__()
|
|
||||||
self.scale = scale_factor
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h, w = x.shape[2:]
|
|
||||||
return F.interpolate(x, size=(int(h * self.scale), int(w * self.scale)))
|
|
||||||
|
|
||||||
|
|
||||||
class YOLOLayer(nn.Module):
|
class YOLOLayer(nn.Module):
|
||||||
def __init__(self, anchors, nc, img_size, yolo_index, arc):
|
def __init__(self, anchors, nc, img_size, yolo_index, arc):
|
||||||
super(YOLOLayer, self).__init__()
|
super(YOLOLayer, self).__init__()
|
||||||
|
|
Loading…
Reference in New Issue