diff --git a/models.py b/models.py index dc008ada..33176bb9 100755 --- a/models.py +++ b/models.py @@ -165,6 +165,8 @@ class YOLOLayer(nn.Module): io[..., 4:] = torch.sigmoid(io[..., 4:]) # p_conf, p_cls # io[..., 5:] = F.softmax(io[..., 5:], dim=4) # p_cls io[..., :4] *= self.stride + if self.nc == 1: # single-class model https://github.com/ultralytics/yolov3/issues/235 + io[..., 5] = 1 # reshape from [1, 3, 13, 13, 85] to [1, 507, 85] return io.view(bs, -1, 5 + self.nc), p