diff --git a/models.py b/models.py index 6345dbf8..683c3a7f 100755 --- a/models.py +++ b/models.py @@ -143,6 +143,7 @@ class YOLOLayer(nn.Module): self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2) if ONNX_EXPORT: + self.training = False self.create_grids((img_size[1] // stride, img_size[0] // stride)) # number x, y grid points def create_grids(self, ng=(13, 13), device='cpu'):