This commit is contained in:
Glenn Jocher 2019-02-09 22:38:51 +01:00
parent f908f845ae
commit 5ec27663e6
1 changed files with 2 additions and 4 deletions

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from utils.parse_config import * from utils.parse_config import *
from utils.utils import * from utils.utils import *
ONNX_EXPORT = False ONNX_EXPORT = True
def create_modules(module_defs): def create_modules(module_defs):
@ -331,9 +331,7 @@ class Darknet(nn.Module):
self.losses['TC'] = 0 self.losses['TC'] = 0
if ONNX_EXPORT: if ONNX_EXPORT:
# Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet) output = torch.cat(output, 1) # merge the 3 layers 85 x (507, 2028, 8112) to 85 x 10647
output = output[1] # first layer reshaped to 85 x 507
# output = torch.cat(output, 1)
return output[5:85].t(), output[:4].t() # ONNX scores, boxes return output[5:85].t(), output[:4].t() # ONNX scores, boxes
return sum(output) if is_training else torch.cat(output, 1) return sum(output) if is_training else torch.cat(output, 1)