create_grids() to YOLOLayer method

This commit is contained in:
Glenn Jocher 2020-04-02 20:23:55 -07:00
parent 91f563c2a2
commit 93055a9d58
1 changed files with 6 additions and 5 deletions

View File

@ -134,8 +134,9 @@ class YOLOLayer(nn.Module):
self.ng = torch.Tensor(ng).to(device) self.ng = torch.Tensor(ng).to(device)
# build xy offsets # build xy offsets
yv, xv = torch.meshgrid([torch.arange(self.ny), torch.arange(self.nx)]) if not self.training:
self.grid_xy = torch.stack((xv, yv), 2).to(device).view((1, 1, self.ny, self.nx, 2)) yv, xv = torch.meshgrid([torch.arange(self.ny, device=device), torch.arange(self.nx, device=device)])
self.grid = torch.stack((xv, yv), 2).view((1, 1, self.ny, self.nx, 2))
if self.anchor_vec.device != device: if self.anchor_vec.device != device:
self.anchor_vec = self.anchor_vec.to(device) self.anchor_vec = self.anchor_vec.to(device)
@ -179,11 +180,11 @@ class YOLOLayer(nn.Module):
# Avoid broadcasting for ANE operations # Avoid broadcasting for ANE operations
m = self.na * self.nx * self.ny m = self.na * self.nx * self.ny
ng = 1 / self.ng.repeat((m, 1)) ng = 1 / self.ng.repeat((m, 1))
grid_xy = self.grid_xy.repeat((1, self.na, 1, 1, 1)).view(m, 2) grid = self.grid.repeat((1, self.na, 1, 1, 1)).view(m, 2)
anchor_wh = self.anchor_wh.repeat((1, 1, self.nx, self.ny, 1)).view(m, 2) * ng anchor_wh = self.anchor_wh.repeat((1, 1, self.nx, self.ny, 1)).view(m, 2) * ng
p = p.view(m, self.no) p = p.view(m, self.no)
xy = torch.sigmoid(p[:, 0:2]) + grid_xy # x, y xy = torch.sigmoid(p[:, 0:2]) + grid # x, y
wh = torch.exp(p[:, 2:4]) * anchor_wh # width, height wh = torch.exp(p[:, 2:4]) * anchor_wh # width, height
p_cls = torch.sigmoid(p[:, 4:5]) if self.nc == 1 else \ p_cls = torch.sigmoid(p[:, 4:5]) if self.nc == 1 else \
torch.sigmoid(p[:, 5:self.no]) * torch.sigmoid(p[:, 4:5]) # conf torch.sigmoid(p[:, 5:self.no]) * torch.sigmoid(p[:, 4:5]) # conf
@ -191,7 +192,7 @@ class YOLOLayer(nn.Module):
else: # inference else: # inference
io = p.clone() # inference output io = p.clone() # inference output
io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid_xy # xy io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid # xy
io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method
io[..., :4] *= self.stride io[..., :4] *= self.stride
torch.sigmoid_(io[..., 4:]) torch.sigmoid_(io[..., 4:])