From 8de043980ae855dbef78e6bf6b859d05869a8088 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 18 Feb 2019 16:25:57 +0100 Subject: [PATCH] updates --- models.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/models.py b/models.py index 4ae9cd2a..535e8ece 100755 --- a/models.py +++ b/models.py @@ -6,7 +6,7 @@ import torch.nn as nn from utils.parse_config import * from utils.utils import * -ONNX_EXPORT = False +ONNX_EXPORT = True def create_modules(module_defs): @@ -212,26 +212,24 @@ class YOLOLayer(nn.Module): else: if ONNX_EXPORT: - p = p.view(-1, 85) - xy = torch.sigmoid(p[:, 0:2]) + self.grid_xy[0] # x, y - wh = torch.exp(p[:, 2:4]) * self.anchor_wh[0] # width, height - p_conf = torch.sigmoid(p[:, 4:5]) # Conf - p_cls = F.softmax(p[:, 5:85], 1) * p_conf # SSD-like conf - return torch.cat((xy / nG, wh, p_conf, p_cls), 1) + # p = p.view(-1, 85) + # xy = torch.sigmoid(p[:, 0:2]) + self.grid_xy[0] # x, y + # wh = torch.exp(p[:, 2:4]) * self.anchor_wh[0] # width, height + # p_conf = torch.sigmoid(p[:, 4:5]) # Conf + # p_cls = F.softmax(p[:, 5:85], 1) * p_conf # SSD-like conf + # return torch.cat((xy / nG, wh, p_conf, p_cls), 1).t() - # p = p.view(1, -1, 85) - # xy = torch.sigmoid(p[..., 0:2]) + self.grid_xy # x, y - # wh = torch.exp(p[..., 2:4]) * self.anchor_wh # width, height - # p_conf = torch.sigmoid(p[..., 4:5]) # Conf - # p_cls = p[..., 5:85] - # - # # Broadcasting only supported on first dimension in CoreML. See onnx-coreml/_operators.py - # # p_cls = F.softmax(p_cls, 2) * p_conf # SSD-like conf - # p_cls = torch.exp(p_cls).permute((2, 1, 0)) - # p_cls = p_cls / p_cls.sum(0).unsqueeze(0) * p_conf.permute((2, 1, 0)) # F.softmax() equivalent - # p_cls = p_cls.permute(2, 1, 0) - # - # return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t() + p = p.view(1, -1, 85) + xy = torch.sigmoid(p[..., 0:2]) + self.grid_xy # x, y + wh = torch.exp(p[..., 2:4]) * self.anchor_wh # width, height + p_conf = torch.sigmoid(p[..., 4:5]) # Conf + p_cls = p[..., 5:85] + # Broadcasting only supported on first dimension in CoreML. See onnx-coreml/_operators.py + # p_cls = F.softmax(p_cls, 2) * p_conf # SSD-like conf + p_cls = torch.exp(p_cls).permute((2, 1, 0)) + p_cls = p_cls / p_cls.sum(0).unsqueeze(0) * p_conf.permute((2, 1, 0)) # F.softmax() equivalent + p_cls = p_cls.permute(2, 1, 0) + return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t() p[..., 0] = torch.sigmoid(p[..., 0]) + self.grid_x # x p[..., 1] = torch.sigmoid(p[..., 1]) + self.grid_y # y @@ -292,8 +290,8 @@ class Darknet(nn.Module): self.losses['nT'] /= 3 if ONNX_EXPORT: - output = torch.cat(output, 0) # merge the 3 layers 85 x (507, 2028, 8112) to 85 x 10647 - return output[:, 5:85], output[:, :4] # ONNX scores, boxes + output = torch.cat(output, 1) # merge the 3 layers 85 x (507, 2028, 8112) to 85 x 10647 + return output[5:85].t(), output[:4].t() # ONNX scores, boxes return sum(output) if is_training else torch.cat(output, 1)