Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
5910353a86
|
@ -55,11 +55,13 @@ def detect(
|
|||
t = time.time()
|
||||
save_path = str(Path(output) / Path(path).name)
|
||||
|
||||
if ONNX_EXPORT:
|
||||
img = torch.zeros((1, 3, 416, 416))
|
||||
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
|
||||
return
|
||||
|
||||
# Get detections
|
||||
img = torch.from_numpy(img).unsqueeze(0).to(device)
|
||||
if ONNX_EXPORT:
|
||||
torch.onnx.export(model, img, 'weights/model.onnx', verbose=True)
|
||||
return
|
||||
pred, _ = model(img)
|
||||
detections = non_max_suppression(pred, conf_thres, nms_thres)[0]
|
||||
|
||||
|
|
14
models.py
14
models.py
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
from utils.parse_config import *
|
||||
from utils.utils import *
|
||||
|
||||
ONNX_EXPORT = False
|
||||
ONNX_EXPORT = True
|
||||
|
||||
|
||||
def create_modules(module_defs):
|
||||
|
@ -234,18 +234,20 @@ def get_yolo_layers(model):
|
|||
|
||||
|
||||
def create_grids(self, img_size, ng, device='cpu'):
|
||||
nx, ny = ng, ng # x and y grid size
|
||||
self.img_size = img_size
|
||||
self.stride = img_size / ng
|
||||
self.stride = img_size / nx
|
||||
|
||||
# build xy offsets
|
||||
grid_x = torch.arange(ng).repeat((ng, 1)).view((1, 1, ng, ng)).float()
|
||||
grid_y = grid_x.permute(0, 1, 3, 2)
|
||||
self.grid_xy = torch.stack((grid_x, grid_y), 4).to(device)
|
||||
yv, xv = torch.meshgrid([torch.arange(nx), torch.arange(ny)])
|
||||
self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, nx, ny, 2))
|
||||
|
||||
# build wh gains
|
||||
self.anchor_vec = self.anchors.to(device) / self.stride
|
||||
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device)
|
||||
self.ng = torch.FloatTensor([ng]).to(device)
|
||||
self.ng = torch.Tensor([ng]).to(device)
|
||||
self.nx = nx
|
||||
self.ny = ny
|
||||
|
||||
|
||||
def load_darknet_weights(self, weights, cutoff=-1):
|
||||
|
|
Loading…
Reference in New Issue