This commit is contained in:
Glenn Jocher 2019-04-25 20:50:37 +02:00
parent c89982d134
commit 324f860235
1 changed files with 7 additions and 6 deletions

View File

@ -114,16 +114,17 @@ class YOLOLayer(nn.Module):
if cfg.endswith('yolov3-tiny.cfg'):
stride *= 2
ng = (int(img_size[0] / stride), int(img_size[1] / stride)) # number grid points
create_grids(self, max(img_size), ng)
nx = int(img_size[1] / stride) # number x grid points
ny = int(img_size[0] / stride) # number y grid points
create_grids(self, max(img_size), (nx, ny))
def forward(self, p, img_size, var=None):
if ONNX_EXPORT:
bs = 1 # batch size
else:
bs, ny, nx = p.shape[0], p.shape[-2], p.shape[-1]
if (self.ny, self.nx) != (ny, nx):
create_grids(self, img_size, (ny, nx), p.device)
if (self.nx, self.ny) != (nx, ny):
create_grids(self, img_size, (nx, ny), p.device)
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh)
p = p.view(bs, self.na, self.nc + 5, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous() # prediction
@ -238,8 +239,8 @@ def get_yolo_layers(model):
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
def create_grids(self, img_size, ng, device='cpu'):
ny, nx = ng # x and y grid size
def create_grids(self, img_size=416, ng=(13, 13), device='cpu'):
nx, ny = ng # x and y grid size
self.img_size = img_size
self.stride = img_size / max(ng)