updates
This commit is contained in:
parent
8646db7c19
commit
e919635564
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = True
|
||||||
|
|
||||||
|
|
||||||
def create_modules(module_defs):
|
def create_modules(module_defs):
|
||||||
|
@ -146,7 +146,7 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
def forward(self, p, targets=None, var=None):
|
def forward(self, p, targets=None, var=None):
|
||||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||||
bs = p.shape[0] # batch size
|
bs = 1 if ONNX_EXPORT else p.shape[0] # batch size
|
||||||
nG = self.nG # number of grid points
|
nG = self.nG # number of grid points
|
||||||
|
|
||||||
if p.is_cuda and not self.weights.is_cuda:
|
if p.is_cuda and not self.weights.is_cuda:
|
||||||
|
@ -299,7 +299,7 @@ def load_darknet_weights(self, weights, cutoff=-1):
|
||||||
# Try to download weights if not available locally
|
# Try to download weights if not available locally
|
||||||
if not os.path.isfile(weights):
|
if not os.path.isfile(weights):
|
||||||
try:
|
try:
|
||||||
os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -P ' + weights)
|
os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -O ' + weights)
|
||||||
except IOError:
|
except IOError:
|
||||||
print(weights + ' not found')
|
print(weights + ' not found')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue