updates
This commit is contained in:
parent
3b694fc8d0
commit
d07b9988e3
|
@ -118,7 +118,7 @@ class YOLOLayer(nn.Module):
|
|||
else:
|
||||
bs, ny, nx = p.shape[0], p.shape[-2], p.shape[-1]
|
||||
if (self.nx, self.ny) != (nx, ny):
|
||||
create_grids(self, img_size, (nx, ny), p.device)
|
||||
create_grids(self, img_size, (nx, ny), p.device, p.dtype)
|
||||
|
||||
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh)
|
||||
p = p.view(bs, self.na, self.nc + 5, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||
|
@ -242,18 +242,18 @@ def get_yolo_layers(model):
|
|||
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
|
||||
|
||||
|
||||
def create_grids(self, img_size=416, ng=(13, 13), device='cpu'):
|
||||
def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float32):
|
||||
nx, ny = ng # x and y grid size
|
||||
self.img_size = img_size
|
||||
self.stride = img_size / max(ng)
|
||||
|
||||
# build xy offsets
|
||||
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
||||
self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, ny, nx, 2))
|
||||
self.grid_xy = torch.stack((xv, yv), 2).to(device).type(type).view((1, 1, ny, nx, 2))
|
||||
|
||||
# build wh gains
|
||||
self.anchor_vec = self.anchors.to(device) / self.stride
|
||||
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device)
|
||||
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device).type(type)
|
||||
self.ng = torch.Tensor(ng).to(device)
|
||||
self.nx = nx
|
||||
self.ny = ny
|
||||
|
|
Loading…
Reference in New Issue