updates
This commit is contained in:
parent
4a4668224b
commit
14e4519620
|
@ -55,11 +55,13 @@ def detect(
|
||||||
t = time.time()
|
t = time.time()
|
||||||
save_path = str(Path(output) / Path(path).name)
|
save_path = str(Path(output) / Path(path).name)
|
||||||
|
|
||||||
|
if ONNX_EXPORT:
|
||||||
|
img = torch.zeros((1, 3, 416, 416))
|
||||||
|
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
|
||||||
|
return
|
||||||
|
|
||||||
# Get detections
|
# Get detections
|
||||||
img = torch.from_numpy(img).unsqueeze(0).to(device)
|
img = torch.from_numpy(img).unsqueeze(0).to(device)
|
||||||
if ONNX_EXPORT:
|
|
||||||
torch.onnx.export(model, img, 'weights/model.onnx', verbose=True)
|
|
||||||
return
|
|
||||||
pred, _ = model(img)
|
pred, _ = model(img)
|
||||||
detections = non_max_suppression(pred, conf_thres, nms_thres)[0]
|
detections = non_max_suppression(pred, conf_thres, nms_thres)[0]
|
||||||
|
|
||||||
|
|
14
models.py
14
models.py
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = True
|
||||||
|
|
||||||
|
|
||||||
def create_modules(module_defs):
|
def create_modules(module_defs):
|
||||||
|
@ -234,18 +234,20 @@ def get_yolo_layers(model):
|
||||||
|
|
||||||
|
|
||||||
def create_grids(self, img_size, ng, device='cpu'):
|
def create_grids(self, img_size, ng, device='cpu'):
|
||||||
|
nx, ny = ng, ng # x and y grid size
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
self.stride = img_size / ng
|
self.stride = img_size / nx
|
||||||
|
|
||||||
# build xy offsets
|
# build xy offsets
|
||||||
grid_x = torch.arange(ng).repeat((ng, 1)).view((1, 1, ng, ng)).float()
|
yv, xv = torch.meshgrid([torch.arange(nx), torch.arange(ny)])
|
||||||
grid_y = grid_x.permute(0, 1, 3, 2)
|
self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, nx, ny, 2))
|
||||||
self.grid_xy = torch.stack((grid_x, grid_y), 4).to(device)
|
|
||||||
|
|
||||||
# build wh gains
|
# build wh gains
|
||||||
self.anchor_vec = self.anchors.to(device) / self.stride
|
self.anchor_vec = self.anchors.to(device) / self.stride
|
||||||
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device)
|
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device)
|
||||||
self.ng = torch.FloatTensor([ng]).to(device)
|
self.ng = torch.Tensor([ng]).to(device)
|
||||||
|
self.nx = nx
|
||||||
|
self.ny = ny
|
||||||
|
|
||||||
|
|
||||||
def load_darknet_weights(self, weights, cutoff=-1):
|
def load_darknet_weights(self, weights, cutoff=-1):
|
||||||
|
|
Loading…
Reference in New Issue