updates
This commit is contained in:
parent
4b15644b46
commit
cc5660e7c0
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = True
|
||||||
|
|
||||||
|
|
||||||
def create_modules(module_defs):
|
def create_modules(module_defs):
|
||||||
|
@ -111,9 +111,6 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
if ONNX_EXPORT: # grids must be computed in __init__
|
if ONNX_EXPORT: # grids must be computed in __init__
|
||||||
stride = [32, 16, 8][yolo_layer] # stride of this layer
|
stride = [32, 16, 8][yolo_layer] # stride of this layer
|
||||||
if cfg.endswith('yolov3-tiny.cfg'):
|
|
||||||
stride *= 2
|
|
||||||
|
|
||||||
nx = int(img_size[1] / stride) # number x grid points
|
nx = int(img_size[1] / stride) # number x grid points
|
||||||
ny = int(img_size[0] / stride) # number y grid points
|
ny = int(img_size[0] / stride) # number y grid points
|
||||||
create_grids(self, max(img_size), (nx, ny))
|
create_grids(self, max(img_size), (nx, ny))
|
||||||
|
@ -215,6 +212,7 @@ class Darknet(nn.Module):
|
||||||
return output
|
return output
|
||||||
elif ONNX_EXPORT:
|
elif ONNX_EXPORT:
|
||||||
output = torch.cat(output, 1) # cat 3 layers 85 x (507, 2028, 8112) to 85 x 10647
|
output = torch.cat(output, 1) # cat 3 layers 85 x (507, 2028, 8112) to 85 x 10647
|
||||||
|
print(output.shape)
|
||||||
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
|
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
|
||||||
else:
|
else:
|
||||||
io, p = list(zip(*output)) # inference output, training output
|
io, p = list(zip(*output)) # inference output, training output
|
||||||
|
|
Loading…
Reference in New Issue