This commit is contained in:
Glenn Jocher 2019-02-17 18:02:56 +01:00
parent 8646db7c19
commit e919635564
1 changed files with 3 additions and 3 deletions

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from utils.parse_config import * from utils.parse_config import *
from utils.utils import * from utils.utils import *
ONNX_EXPORT = False ONNX_EXPORT = True
def create_modules(module_defs): def create_modules(module_defs):
@ -146,7 +146,7 @@ class YOLOLayer(nn.Module):
def forward(self, p, targets=None, var=None): def forward(self, p, targets=None, var=None):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0] # batch size bs = 1 if ONNX_EXPORT else p.shape[0] # batch size
nG = self.nG # number of grid points nG = self.nG # number of grid points
if p.is_cuda and not self.weights.is_cuda: if p.is_cuda and not self.weights.is_cuda:
@ -299,7 +299,7 @@ def load_darknet_weights(self, weights, cutoff=-1):
# Try to download weights if not available locally # Try to download weights if not available locally
if not os.path.isfile(weights): if not os.path.isfile(weights):
try: try:
os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -P ' + weights) os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -O ' + weights)
except IOError: except IOError:
print(weights + ' not found') print(weights + ' not found')