diff --git a/models.py b/models.py index 4c70baa1..dd77c471 100755 --- a/models.py +++ b/models.py @@ -223,8 +223,7 @@ def create_grids(self, img_size, nG, device='cpu'): # build wh gains self.anchor_vec = self.anchors.to(device) / self.stride self.anchor_wh = self.anchor_vec.view(1, self.nA, 1, 1, 2).to(device) - self.nG = torch.Tensor([nG], device=device) - + self.nG = torch.FloatTensor([nG]).to(device) def load_darknet_weights(self, weights, cutoff=-1): # Parses and loads the weights stored in 'weights'