grid.float()
This commit is contained in:
parent
93055a9d58
commit
41a002e798
|
@ -136,7 +136,7 @@ class YOLOLayer(nn.Module):
|
|||
# build xy offsets
|
||||
if not self.training:
|
||||
yv, xv = torch.meshgrid([torch.arange(self.ny, device=device), torch.arange(self.nx, device=device)])
|
||||
self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2))
|
||||
self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2)).float()
|
||||
|
||||
if self.anchor_vec.device != device:
|
||||
self.anchor_vec = self.anchor_vec.to(device)
|
||||
|
|
Loading…
Reference in New Issue