This commit is contained in:
Glenn Jocher 2019-05-14 18:37:13 +02:00
parent 4b15644b46
commit cc5660e7c0
1 changed files with 2 additions and 4 deletions

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from utils.parse_config import * from utils.parse_config import *
from utils.utils import * from utils.utils import *
ONNX_EXPORT = False ONNX_EXPORT = True
def create_modules(module_defs): def create_modules(module_defs):
@ -111,9 +111,6 @@ class YOLOLayer(nn.Module):
if ONNX_EXPORT: # grids must be computed in __init__ if ONNX_EXPORT: # grids must be computed in __init__
stride = [32, 16, 8][yolo_layer] # stride of this layer stride = [32, 16, 8][yolo_layer] # stride of this layer
if cfg.endswith('yolov3-tiny.cfg'):
stride *= 2
nx = int(img_size[1] / stride) # number x grid points nx = int(img_size[1] / stride) # number x grid points
ny = int(img_size[0] / stride) # number y grid points ny = int(img_size[0] / stride) # number y grid points
create_grids(self, max(img_size), (nx, ny)) create_grids(self, max(img_size), (nx, ny))
@ -215,6 +212,7 @@ class Darknet(nn.Module):
return output return output
elif ONNX_EXPORT: elif ONNX_EXPORT:
output = torch.cat(output, 1) # cat 3 layers 85 x (507, 2028, 8112) to 85 x 10647 output = torch.cat(output, 1) # cat 3 layers 85 x (507, 2028, 8112) to 85 x 10647
print(output.shape)
return output[5:85].t(), output[:4].t() # ONNX scores, boxes return output[5:85].t(), output[:4].t() # ONNX scores, boxes
else: else:
io, p = list(zip(*output)) # inference output, training output io, p = list(zip(*output)) # inference output, training output