grid.float()

This commit is contained in:
Glenn Jocher 2020-04-03 12:38:08 -07:00
parent 93055a9d58
commit 41a002e798
1 changed files with 1 additions and 1 deletions

View File

@ -136,7 +136,7 @@ class YOLOLayer(nn.Module):
# build xy offsets # build xy offsets
if not self.training: if not self.training:
yv, xv = torch.meshgrid([torch.arange(self.ny, device=device), torch.arange(self.nx, device=device)]) yv, xv = torch.meshgrid([torch.arange(self.ny, device=device), torch.arange(self.nx, device=device)])
self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2)) self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2)).float()
if self.anchor_vec.device != device: if self.anchor_vec.device != device:
self.anchor_vec = self.anchor_vec.to(device) self.anchor_vec = self.anchor_vec.to(device)