This commit is contained in:
Glenn Jocher 2019-01-03 23:41:31 +01:00
parent b181c61f4b
commit cff2a81315
2 changed files with 32 additions and 15 deletions

View File

@ -7,7 +7,6 @@ from utils.utils import *
from utils import torch_utils
def detect(
net_config_path,
data_config_path,
@ -68,7 +67,8 @@ def detect(
with torch.no_grad():
# cv2.imwrite('zidane_416.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # letterboxed
img = torch.from_numpy(img).unsqueeze(0).to(device)
# pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True); return # ONNX export
if ONNX_EXPORT:
pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True); return # ONNX export
pred = model(img)
pred = pred[pred[:, :, 4] > conf_thres]

View File

@ -5,6 +5,8 @@ import torch.nn as nn
from utils.parse_config import *
from utils.utils import *
ONNX_EXPORT = True
def create_modules(module_defs):
"""
@ -80,7 +82,7 @@ class EmptyLayer(nn.Module):
super(EmptyLayer, self).__init__()
class Upsample(torch.nn.Module):
class Upsample(nn.Module):
# Custom Upsample layer (nn.Upsample gives deprecated warning message)
def __init__(self, scale_factor=1, mode='nearest'):
@ -120,22 +122,30 @@ class YOLOLayer(nn.Module):
nG = int(self.img_dim / stride) # number grid points
self.grid_x = torch.arange(nG).repeat(nG, 1).view([1, 1, nG, nG]).float()
self.grid_y = torch.arange(nG).repeat(nG, 1).t().view([1, 1, nG, nG]).float()
self.grid_y = torch.arange(nG).repeat(nG, 1).t().view([1, 1, nG, nG]).float()
self.anchor_wh = torch.FloatTensor([(a_w / stride, a_h / stride) for a_w, a_h in anchors]) # scale anchors
self.anchor_w = self.anchor_wh[:, 0:1].view((1, nA, 1, 1))
self.anchor_h = self.anchor_wh[:, 1:2].view((1, nA, 1, 1))
self.anchor_w = self.anchor_wh[:, 0].view((1, nA, 1, 1))
self.anchor_h = self.anchor_wh[:, 1].view((1, nA, 1, 1))
self.weights = class_weights()
self.loss_means = torch.ones(6)
self.tx, self.ty, self.tw, self.th = [], [], [], []
self.yolo_layer = anchor_idxs[0] / nA # 2, 1, 0
self.stride = stride
if ONNX_EXPORT: # use fully populated and reshaped tensors
self.anchor_w = self.anchor_w.repeat((1, 1, nG, nG)).view(1, -1, 1)
self.anchor_h = self.anchor_h.repeat((1, 1, nG, nG)).view(1, -1, 1)
self.grid_x = self.grid_x.repeat(1, nA, 1, 1).view(1, -1, 1)
self.grid_y = self.grid_y.repeat(1, nA, 1, 1).view(1, -1, 1)
self.grid_xy = torch.cat((self.grid_x, self.grid_y), 2)
self.anchor_wh = torch.cat((self.anchor_w, self.anchor_h), 2) / nG
def forward(self, p, targets=None, batch_report=False, var=None):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0] # batch size
nG = p.shape[2] # number of grid points
if p.is_cuda and not self.grid_x.is_cuda:
if p.is_cuda and not self.weights.is_cuda:
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
self.weights, self.loss_means = self.weights.cuda(), self.loss_means.cuda()
@ -239,16 +249,25 @@ class YOLOLayer(nn.Module):
nT, TP, FP, FPe, FN, TC
else:
stride = self.img_dim / nG
if ONNX_EXPORT:
p = p.view(1, -1, 85)
xy = torch.sigmoid(p[..., 0:2]) + self.grid_xy # x, y
width_height = torch.exp(p[..., 2:4]) * self.anchor_wh # width, height
p_conf = torch.sigmoid(p[..., 4:5]) # Conf
## p_cls = torch.sigmoid(p[..., 5:85]) # Class
p_cls = F.softmax(p[..., 5:85], 2) * p_conf # SSD-like conf
# p_cls = torch.exp(p[..., 5:85]) / torch.exp(p[..., 5:85]).sum(2).unsqueeze(2) #* p_conf # F.softmax() equivalent
return torch.cat((xy / nG, width_height, p_conf, p_cls), 2).squeeze().t()
p[..., 0] = torch.sigmoid(p[..., 0]) + self.grid_x # x
p[..., 1] = torch.sigmoid(p[..., 1]) + self.grid_y # y
p[..., 2] = torch.exp(p[..., 2]) * self.anchor_w # width
p[..., 3] = torch.exp(p[..., 3]) * self.anchor_h # height
p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf
p[..., :4] *= stride
p[..., :4] *= self.stride
# reshape from [1, 3, 13, 13, 85] to [1, 507, 85]
return p.view(bs, self.nA * nG * nG, 5 + self.nC)
return p.view(bs, -1, 5 + self.nC)
class Darknet(nn.Module):
@ -316,12 +335,10 @@ class Darknet(nn.Module):
self.losses['nT'] /= 3
self.losses['TC'] = 0
ONNX_export = False
if ONNX_export:
if ONNX_EXPORT:
# Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet)
output = output[0].squeeze().transpose(0, 1) # first layer reshaped to 85 x 507
output[5:85] = F.softmax(output[5:85], dim=0) * output[4:5] # SSD-like conf
return output[5:85], output[:4] # ONNX scores, boxes
output = output[0] # first layer reshaped to 85 x 507
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
return sum(output) if is_training else torch.cat(output, 1)