updates
This commit is contained in:
parent
f908f845ae
commit
5ec27663e6
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
from utils.parse_config import *
|
||||
from utils.utils import *
|
||||
|
||||
ONNX_EXPORT = False
|
||||
ONNX_EXPORT = True
|
||||
|
||||
|
||||
def create_modules(module_defs):
|
||||
|
@ -331,9 +331,7 @@ class Darknet(nn.Module):
|
|||
self.losses['TC'] = 0
|
||||
|
||||
if ONNX_EXPORT:
|
||||
# Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet)
|
||||
output = output[1] # first layer reshaped to 85 x 507
|
||||
# output = torch.cat(output, 1)
|
||||
output = torch.cat(output, 1) # merge the 3 layers 85 x (507, 2028, 8112) to 85 x 10647
|
||||
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
|
||||
|
||||
return sum(output) if is_training else torch.cat(output, 1)
|
||||
|
|
Loading…
Reference in New Issue