diff --git a/models.py b/models.py index a2db26e9..3d3acad3 100755 --- a/models.py +++ b/models.py @@ -190,7 +190,7 @@ class YOLOLayer(nn.Module): # return torch.cat((xy / nG, wh, p_conf, p_cls), 1).t() p = p.view(1, -1, 85) - xy = xy + grid_xy # x, y + xy = xy.view(bs, self.nA * nG * nG, 2) + grid_xy # x, y wh = torch.exp(p[..., 2:4]) * anchor_wh # width, height p_conf = torch.sigmoid(p[..., 4:5]) # Conf p_cls = p[..., 5:85]