From 636c1cff7a91c0b54c996ef48b36274b08e4a8b8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 11 Aug 2019 15:17:40 +0200 Subject: [PATCH] updates --- models.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/models.py b/models.py index a453993e..94f357b3 100755 --- a/models.py +++ b/models.py @@ -4,7 +4,7 @@ from pathlib import Path import torch.nn.functional as F -ONNX_EXPORT = False +ONNX_EXPORT = True def create_modules(module_defs): @@ -34,7 +34,7 @@ def create_modules(module_defs): modules.add_module('BatchNorm2d', nn.BatchNorm2d(filters, momentum=0.1)) if mdef['activation'] == 'leaky': # TODO: activation study https://github.com/ultralytics/yolov3/issues/441 modules.add_module('activation', nn.LeakyReLU(0.1, inplace=True)) - # modules.add_module('activation', nn.PReLU(num_parameters=1)) + # modules.add_module('activation', nn.PReLU(num_parameters=1, init=0.10)) # modules.add_module('activation', Swish()) elif mdef['type'] == 'maxpool': @@ -105,7 +105,7 @@ class YOLOLayer(nn.Module): stride = [32, 16, 8][yolo_index] # stride of this layer nx = int(img_size[1] / stride) # number x grid points ny = int(img_size[0] / stride) # number y grid points - create_grids(self, max(img_size), (nx, ny)) + create_grids(self, img_size, (nx, ny)) def forward(self, p, img_size, var=None): if ONNX_EXPORT: @@ -127,33 +127,33 @@ class YOLOLayer(nn.Module): grid_xy = self.grid_xy.repeat((1, self.na, 1, 1, 1)).view((1, -1, 2)) anchor_wh = self.anchor_wh.repeat((1, 1, self.nx, self.ny, 1)).view((1, -1, 2)) / ngu - # p = p.view(-1, 5 + self.nc) - # xy = torch.sigmoid(p[..., 0:2]) + grid_xy[0] # x, y - # wh = torch.exp(p[..., 2:4]) * anchor_wh[0] # width, height - # p_conf = torch.sigmoid(p[:, 4:5]) # Conf - # p_cls = F.softmax(p[:, 5:85], 1) * p_conf # SSD-like conf - # return torch.cat((xy / ngu[0], wh, p_conf, p_cls), 1).t() + p = p.view(-1, 5 + self.nc) + xy = torch.sigmoid(p[..., 0:2]) + grid_xy[0] # x, y + wh = torch.exp(p[..., 2:4]) * anchor_wh[0] # width, height + p_conf = torch.sigmoid(p[:, 4:5]) # Conf + p_cls = F.softmax(p[:, 5:85], 1) * p_conf # SSD-like conf + return torch.cat((xy / ngu[0], wh, p_conf, p_cls), 1).t() - p = p.view(1, -1, 5 + self.nc) - xy = torch.sigmoid(p[..., 0:2]) + grid_xy # x, y - wh = torch.exp(p[..., 2:4]) * anchor_wh # width, height - p_conf = torch.sigmoid(p[..., 4:5]) # Conf - p_cls = p[..., 5:5 + self.nc] - # Broadcasting only supported on first dimension in CoreML. See onnx-coreml/_operators.py - # p_cls = F.softmax(p_cls, 2) * p_conf # SSD-like conf - p_cls = torch.exp(p_cls).permute((2, 1, 0)) - p_cls = p_cls / p_cls.sum(0).unsqueeze(0) * p_conf.permute((2, 1, 0)) # F.softmax() equivalent - p_cls = p_cls.permute(2, 1, 0) - return torch.cat((xy / ngu, wh, p_conf, p_cls), 2).squeeze().t() + # p = p.view(1, -1, 5 + self.nc) + # xy = torch.sigmoid(p[..., 0:2]) + grid_xy # x, y + # wh = torch.exp(p[..., 2:4]) * anchor_wh # width, height + # p_conf = torch.sigmoid(p[..., 4:5]) # Conf + # p_cls = p[..., 5:5 + self.nc] + # # Broadcasting only supported on first dimension in CoreML. See onnx-coreml/_operators.py + # # p_cls = F.softmax(p_cls, 2) * p_conf # SSD-like conf + # p_cls = torch.exp(p_cls).permute((2, 1, 0)) + # p_cls = p_cls / p_cls.sum(0).unsqueeze(0) * p_conf.permute((2, 1, 0)) # F.softmax() equivalent + # p_cls = p_cls.permute(2, 1, 0) + # return torch.cat((xy / ngu, wh, p_conf, p_cls), 2).squeeze().t() else: # inference # s = 1.5 # scale_xy (pxy = pxy * s - (s - 1) / 2) io = p.clone() # inference output io[..., 0:2] = torch.sigmoid(io[..., 0:2]) + self.grid_xy # xy io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method + # io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method io[..., :4] *= self.stride - # io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method io[..., 4:] = torch.sigmoid(io[..., 4:]) # p_conf, p_cls # io[..., 5:] = F.softmax(io[..., 5:], dim=4) # p_cls @@ -181,7 +181,7 @@ class Darknet(nn.Module): self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training def forward(self, x, var=None): - img_size = max(x.shape[-2:]) + img_size = x.shape[-2:] layer_outputs = [] output = [] @@ -241,8 +241,8 @@ def get_yolo_layers(model): def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float32): nx, ny = ng # x and y grid size - self.img_size = img_size - self.stride = img_size / max(ng) + self.img_size = max(img_size) + self.stride = self.img_size / max(ng) # build xy offsets yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])