From d07b9988e34e1fa1a2bda221abae92fabcbbc0e2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 1 Aug 2019 00:33:17 +0200 Subject: [PATCH] updates --- models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models.py b/models.py index 0727be6d..24c53969 100755 --- a/models.py +++ b/models.py @@ -118,7 +118,7 @@ class YOLOLayer(nn.Module): else: bs, ny, nx = p.shape[0], p.shape[-2], p.shape[-1] if (self.nx, self.ny) != (nx, ny): - create_grids(self, img_size, (nx, ny), p.device) + create_grids(self, img_size, (nx, ny), p.device, p.dtype) # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh) p = p.view(bs, self.na, self.nc + 5, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous() # prediction @@ -242,18 +242,18 @@ def get_yolo_layers(model): return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3 -def create_grids(self, img_size=416, ng=(13, 13), device='cpu'): +def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float32): nx, ny = ng # x and y grid size self.img_size = img_size self.stride = img_size / max(ng) # build xy offsets yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) - self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, ny, nx, 2)) + self.grid_xy = torch.stack((xv, yv), 2).to(device).type(type).view((1, 1, ny, nx, 2)) # build wh gains self.anchor_vec = self.anchors.to(device) / self.stride - self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device) + self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device).type(type) self.ng = torch.Tensor(ng).to(device) self.nx = nx self.ny = ny