Fuse by default when test.py called directly (faster)
This commit is contained in:
parent
fff45c39a8
commit
89b6377723
|
@ -311,7 +311,7 @@ class Darknet(nn.Module):
|
||||||
|
|
||||||
def fuse(self):
|
def fuse(self):
|
||||||
# Fuse Conv2d + BatchNorm2d layers throughout model
|
# Fuse Conv2d + BatchNorm2d layers throughout model
|
||||||
print('Fusing Conv2d() and BatchNorm2d() layers...')
|
print('Fusing layers...')
|
||||||
fused_list = nn.ModuleList()
|
fused_list = nn.ModuleList()
|
||||||
for a in list(self.children())[0]:
|
for a in list(self.children())[0]:
|
||||||
if isinstance(a, nn.Sequential):
|
if isinstance(a, nn.Sequential):
|
||||||
|
|
6
test.py
6
test.py
|
@ -29,7 +29,7 @@ def test(cfg,
|
||||||
os.remove(f)
|
os.remove(f)
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = Darknet(cfg, img_size).to(device)
|
model = Darknet(cfg, img_size)
|
||||||
|
|
||||||
# Load weights
|
# Load weights
|
||||||
attempt_download(weights)
|
attempt_download(weights)
|
||||||
|
@ -38,6 +38,10 @@ def test(cfg,
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
load_darknet_weights(model, weights)
|
load_darknet_weights(model, weights)
|
||||||
|
|
||||||
|
# Fuse
|
||||||
|
model.fuse()
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
else: # called by train.py
|
else: # called by train.py
|
||||||
|
|
Loading…
Reference in New Issue