grid.float()
This commit is contained in:
parent
93055a9d58
commit
41a002e798
|
@ -136,7 +136,7 @@ class YOLOLayer(nn.Module):
|
||||||
# build xy offsets
|
# build xy offsets
|
||||||
if not self.training:
|
if not self.training:
|
||||||
yv, xv = torch.meshgrid([torch.arange(self.ny, device=device), torch.arange(self.nx, device=device)])
|
yv, xv = torch.meshgrid([torch.arange(self.ny, device=device), torch.arange(self.nx, device=device)])
|
||||||
self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2))
|
self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2)).float()
|
||||||
|
|
||||||
if self.anchor_vec.device != device:
|
if self.anchor_vec.device != device:
|
||||||
self.anchor_vec = self.anchor_vec.to(device)
|
self.anchor_vec = self.anchor_vec.to(device)
|
||||||
|
|
Loading…
Reference in New Issue