diff --git a/models.py b/models.py index 528ede98..cf95a959 100755 --- a/models.py +++ b/models.py @@ -162,7 +162,7 @@ class YOLOLayer(nn.Module): if ONNX_EXPORT: bs = 1 # batch size else: - bs, ny, nx = p.shape[0], p.shape[-2], p.shape[-1] + bs, _, ny, nx = p.shape # bs, 255, 13, 13 if (self.nx, self.ny) != (nx, ny): create_grids(self, img_size, (nx, ny), p.device, p.dtype)