From e9196355647d957e6a39148d4a20e270ebd2d469 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 17 Feb 2019 18:02:56 +0100 Subject: [PATCH] updates --- models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models.py b/models.py index 181950f9..a7ffc4fa 100755 --- a/models.py +++ b/models.py @@ -6,7 +6,7 @@ import torch.nn as nn from utils.parse_config import * from utils.utils import * -ONNX_EXPORT = False +ONNX_EXPORT = True def create_modules(module_defs): @@ -146,7 +146,7 @@ class YOLOLayer(nn.Module): def forward(self, p, targets=None, var=None): FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor - bs = p.shape[0] # batch size + bs = 1 if ONNX_EXPORT else p.shape[0] # batch size nG = self.nG # number of grid points if p.is_cuda and not self.weights.is_cuda: @@ -299,7 +299,7 @@ def load_darknet_weights(self, weights, cutoff=-1): # Try to download weights if not available locally if not os.path.isfile(weights): try: - os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -P ' + weights) + os.system('wget https://pjreddie.com/media/files/' + weights_file + ' -O ' + weights) except IOError: print(weights + ' not found')