diff --git a/models.py b/models.py index cf5435b5..4a57ef59 100755 --- a/models.py +++ b/models.py @@ -179,7 +179,8 @@ class YOLOLayer(nn.Module): p = p.view(m, self.no) xy = torch.sigmoid(p[:, 0:2]) + grid_xy # x, y wh = torch.exp(p[:, 2:4]) * anchor_wh # width, height - p_cls = torch.sigmoid(p[:, 5:self.no]) * torch.sigmoid(p[:, 4:5]) # conf + p_cls = torch.sigmoid(p[:, 4:5]) if self.nc == 1 else \ + torch.sigmoid(p[:, 5:self.no]) * torch.sigmoid(p[:, 4:5]) # conf return p_cls, xy / self.ng, wh else: # inference