From 46e3343494d38caa183c49ce54ed3fdcd10a0e97 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 21 Feb 2019 16:16:35 +0100 Subject: [PATCH] updates --- models.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/models.py b/models.py index 7a659404..1ecfdc18 100755 --- a/models.py +++ b/models.py @@ -135,9 +135,10 @@ class YOLOLayer(nn.Module): # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh) p = p.view(bs, self.nA, self.nC + 5, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction - # Width and height - wh = p[..., 2:4] # yolo method - # wh = torch.sigmoid(p[..., 2:4]) # power method + # xy, width and height + xy = torch.sigmoid(p[..., 0:2]) + wh = p[..., 2:4] # wh (yolo method) + # wh = torch.sigmoid(p[..., 2:4]) # wh (power method) # Training if targets is not None: @@ -146,7 +147,6 @@ class YOLOLayer(nn.Module): CrossEntropyLoss = nn.CrossEntropyLoss() # Get outputs - xy = torch.sigmoid(p[..., 0:2]) p_conf = p[..., 4] # Conf p_cls = p[..., 5:] # Class @@ -160,7 +160,6 @@ class YOLOLayer(nn.Module): nT = sum([len(x) for x in targets]) # number of targets nM = mask.sum().float() # number of anchors (assigned to targets) k = 1 # nM / bs - if nM > 0: lxy = k * MSELoss(xy[mask], txy[mask]) lwh = k * MSELoss(wh[mask], twh[mask]) @@ -184,14 +183,14 @@ class YOLOLayer(nn.Module): anchor_wh = self.anchor_wh.repeat((1, 1, nG, nG, 1)).view((1, -1, 2)) / nG # p = p.view(-1, 85) - # xy = torch.sigmoid(p[:, 0:2]) + self.grid_xy[0] # x, y - # wh = torch.exp(p[:, 2:4]) * self.anchor_wh[0] # width, height + # xy = xy + self.grid_xy[0] # x, y + # wh = torch.exp(wh) * self.anchor_wh[0] # width, height # p_conf = torch.sigmoid(p[:, 4:5]) # Conf # p_cls = F.softmax(p[:, 5:85], 1) * p_conf # SSD-like conf # return torch.cat((xy / nG, wh, p_conf, p_cls), 1).t() p = p.view(1, -1, 85) - xy = torch.sigmoid(p[..., 0:2]) + grid_xy # x, y + xy = xy + grid_xy # x, y wh = torch.exp(p[..., 2:4]) * anchor_wh # width, height p_conf = torch.sigmoid(p[..., 4:5]) # Conf p_cls = p[..., 5:85] @@ -202,7 +201,7 @@ class YOLOLayer(nn.Module): p_cls = p_cls.permute(2, 1, 0) return torch.cat((xy / nG, wh, p_conf, p_cls), 2).squeeze().t() - p[..., 0:2] = torch.sigmoid(p[..., 0:2]) + self.grid_xy # xy + p[..., 0:2] = xy + self.grid_xy # xy p[..., 2:4] = torch.exp(wh) * self.anchor_wh # wh yolo method # p[..., 2:4] = ((wh * 2) ** 2) * self.anchor_wh # wh power method p[..., 4] = torch.sigmoid(p[..., 4]) # p_conf