updates
This commit is contained in:
parent
f908f845ae
commit
5ec27663e6
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = True
|
||||||
|
|
||||||
|
|
||||||
def create_modules(module_defs):
|
def create_modules(module_defs):
|
||||||
|
@ -331,9 +331,7 @@ class Darknet(nn.Module):
|
||||||
self.losses['TC'] = 0
|
self.losses['TC'] = 0
|
||||||
|
|
||||||
if ONNX_EXPORT:
|
if ONNX_EXPORT:
|
||||||
# Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet)
|
output = torch.cat(output, 1) # merge the 3 layers 85 x (507, 2028, 8112) to 85 x 10647
|
||||||
output = output[1] # first layer reshaped to 85 x 507
|
|
||||||
# output = torch.cat(output, 1)
|
|
||||||
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
|
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
|
||||||
|
|
||||||
return sum(output) if is_training else torch.cat(output, 1)
|
return sum(output) if is_training else torch.cat(output, 1)
|
||||||
|
|
Loading…
Reference in New Issue