Merge remote-tracking branch 'origin/master'

This commit is contained in:
Glenn Jocher 2019-04-21 20:35:19 +02:00
commit 5910353a86
2 changed files with 13 additions and 9 deletions

View File

@ -55,11 +55,13 @@ def detect(
t = time.time()
save_path = str(Path(output) / Path(path).name)
if ONNX_EXPORT:
img = torch.zeros((1, 3, 416, 416))
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
return
# Get detections
img = torch.from_numpy(img).unsqueeze(0).to(device)
if ONNX_EXPORT:
torch.onnx.export(model, img, 'weights/model.onnx', verbose=True)
return
pred, _ = model(img)
detections = non_max_suppression(pred, conf_thres, nms_thres)[0]

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from utils.parse_config import *
from utils.utils import *
ONNX_EXPORT = False
ONNX_EXPORT = True
def create_modules(module_defs):
@ -234,18 +234,20 @@ def get_yolo_layers(model):
def create_grids(self, img_size, ng, device='cpu'):
nx, ny = ng, ng # x and y grid size
self.img_size = img_size
self.stride = img_size / ng
self.stride = img_size / nx
# build xy offsets
grid_x = torch.arange(ng).repeat((ng, 1)).view((1, 1, ng, ng)).float()
grid_y = grid_x.permute(0, 1, 3, 2)
self.grid_xy = torch.stack((grid_x, grid_y), 4).to(device)
yv, xv = torch.meshgrid([torch.arange(nx), torch.arange(ny)])
self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, nx, ny, 2))
# build wh gains
self.anchor_vec = self.anchors.to(device) / self.stride
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device)
self.ng = torch.FloatTensor([ng]).to(device)
self.ng = torch.Tensor([ng]).to(device)
self.nx = nx
self.ny = ny
def load_darknet_weights(self, weights, cutoff=-1):