updates
This commit is contained in:
parent
c89982d134
commit
324f860235
13
models.py
13
models.py
|
@ -114,16 +114,17 @@ class YOLOLayer(nn.Module):
|
||||||
if cfg.endswith('yolov3-tiny.cfg'):
|
if cfg.endswith('yolov3-tiny.cfg'):
|
||||||
stride *= 2
|
stride *= 2
|
||||||
|
|
||||||
ng = (int(img_size[0] / stride), int(img_size[1] / stride)) # number grid points
|
nx = int(img_size[1] / stride) # number x grid points
|
||||||
create_grids(self, max(img_size), ng)
|
ny = int(img_size[0] / stride) # number y grid points
|
||||||
|
create_grids(self, max(img_size), (nx, ny))
|
||||||
|
|
||||||
def forward(self, p, img_size, var=None):
|
def forward(self, p, img_size, var=None):
|
||||||
if ONNX_EXPORT:
|
if ONNX_EXPORT:
|
||||||
bs = 1 # batch size
|
bs = 1 # batch size
|
||||||
else:
|
else:
|
||||||
bs, ny, nx = p.shape[0], p.shape[-2], p.shape[-1]
|
bs, ny, nx = p.shape[0], p.shape[-2], p.shape[-1]
|
||||||
if (self.ny, self.nx) != (ny, nx):
|
if (self.nx, self.ny) != (nx, ny):
|
||||||
create_grids(self, img_size, (ny, nx), p.device)
|
create_grids(self, img_size, (nx, ny), p.device)
|
||||||
|
|
||||||
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh)
|
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh)
|
||||||
p = p.view(bs, self.na, self.nc + 5, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
p = p.view(bs, self.na, self.nc + 5, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||||
|
@ -238,8 +239,8 @@ def get_yolo_layers(model):
|
||||||
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
|
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
|
||||||
|
|
||||||
|
|
||||||
def create_grids(self, img_size, ng, device='cpu'):
|
def create_grids(self, img_size=416, ng=(13, 13), device='cpu'):
|
||||||
ny, nx = ng # x and y grid size
|
nx, ny = ng # x and y grid size
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
self.stride = img_size / max(ng)
|
self.stride = img_size / max(ng)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue