This commit is contained in:
Glenn Jocher 2019-08-01 00:33:17 +02:00
parent 3b694fc8d0
commit d07b9988e3
1 changed files with 4 additions and 4 deletions

View File

@ -118,7 +118,7 @@ class YOLOLayer(nn.Module):
else:
bs, ny, nx = p.shape[0], p.shape[-2], p.shape[-1]
if (self.nx, self.ny) != (nx, ny):
create_grids(self, img_size, (nx, ny), p.device)
create_grids(self, img_size, (nx, ny), p.device, p.dtype)
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh)
p = p.view(bs, self.na, self.nc + 5, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous() # prediction
@ -242,18 +242,18 @@ def get_yolo_layers(model):
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
def create_grids(self, img_size=416, ng=(13, 13), device='cpu'):
def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float32):
nx, ny = ng # x and y grid size
self.img_size = img_size
self.stride = img_size / max(ng)
# build xy offsets
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, ny, nx, 2))
self.grid_xy = torch.stack((xv, yv), 2).to(device).type(type).view((1, 1, ny, nx, 2))
# build wh gains
self.anchor_vec = self.anchors.to(device) / self.stride
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device)
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device).type(type)
self.ng = torch.Tensor(ng).to(device)
self.nx = nx
self.ny = ny