torch.tensor(ng, device=device)

This commit is contained in:
Glenn Jocher 2020-04-12 13:02:00 -07:00
parent efc754a794
commit 46726dad13
1 changed files with 2 additions and 2 deletions

View File

@ -138,7 +138,7 @@ class YOLOLayer(nn.Module):
self.na = len(anchors) # number of anchors (3)
self.nc = nc # number of classes (80)
self.no = nc + 5 # number of outputs (85)
self.nx, self.ny = 0, 0 # initialize number of x, y gridpoints
self.nx, self.ny, self.ng = 0, 0, 0 # initialize number of x, y gridpoints
self.anchor_vec = self.anchors / self.stride
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2)
@ -148,7 +148,7 @@ class YOLOLayer(nn.Module):
def create_grids(self, ng=(13, 13), device='cpu'):
self.nx, self.ny = ng # x and y grid size
self.ng = torch.Tensor(ng).to(device)
self.ng = torch.tensor(ng, device=device)
# build xy offsets
if not self.training: