torch.tensor(ng, device=device)
This commit is contained in:
parent
efc754a794
commit
46726dad13
|
@ -138,7 +138,7 @@ class YOLOLayer(nn.Module):
|
||||||
self.na = len(anchors) # number of anchors (3)
|
self.na = len(anchors) # number of anchors (3)
|
||||||
self.nc = nc # number of classes (80)
|
self.nc = nc # number of classes (80)
|
||||||
self.no = nc + 5 # number of outputs (85)
|
self.no = nc + 5 # number of outputs (85)
|
||||||
self.nx, self.ny = 0, 0 # initialize number of x, y gridpoints
|
self.nx, self.ny, self.ng = 0, 0, 0 # initialize number of x, y gridpoints
|
||||||
self.anchor_vec = self.anchors / self.stride
|
self.anchor_vec = self.anchors / self.stride
|
||||||
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2)
|
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2)
|
||||||
|
|
||||||
|
@ -148,7 +148,7 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
def create_grids(self, ng=(13, 13), device='cpu'):
|
def create_grids(self, ng=(13, 13), device='cpu'):
|
||||||
self.nx, self.ny = ng # x and y grid size
|
self.nx, self.ny = ng # x and y grid size
|
||||||
self.ng = torch.Tensor(ng).to(device)
|
self.ng = torch.tensor(ng, device=device)
|
||||||
|
|
||||||
# build xy offsets
|
# build xy offsets
|
||||||
if not self.training:
|
if not self.training:
|
||||||
|
|
Loading…
Reference in New Issue