create_grids() to YOLOLayer method
This commit is contained in:
parent
207c6fcff9
commit
91f563c2a2
52
models.py
52
models.py
|
@ -8,6 +8,7 @@ ONNX_EXPORT = False
|
|||
def create_modules(module_defs, img_size):
|
||||
# Constructs module list of layer blocks from module configuration in module_defs
|
||||
|
||||
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
|
||||
hyperparams = module_defs.pop(0)
|
||||
output_filters = [int(hyperparams['channels'])]
|
||||
module_list = nn.ModuleList()
|
||||
|
@ -75,12 +76,13 @@ def create_modules(module_defs, img_size):
|
|||
|
||||
elif mdef['type'] == 'yolo':
|
||||
yolo_index += 1
|
||||
l = mdef['from'] if 'from' in mdef else []
|
||||
stride = [32, 16, 8, 4, 2][yolo_index] # P3-P7 stride
|
||||
modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']], # anchor list
|
||||
nc=mdef['classes'], # number of classes
|
||||
img_size=img_size, # (416, 416)
|
||||
yolo_index=yolo_index, # 0, 1, 2...
|
||||
layers=l) # output layers
|
||||
layers=mdef['from'] if 'from' in mdef else [], # output layers
|
||||
stride=stride)
|
||||
|
||||
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
|
||||
try:
|
||||
|
@ -110,23 +112,34 @@ def create_modules(module_defs, img_size):
|
|||
|
||||
|
||||
class YOLOLayer(nn.Module):
|
||||
def __init__(self, anchors, nc, img_size, yolo_index, layers):
|
||||
def __init__(self, anchors, nc, img_size, yolo_index, layers, stride):
|
||||
super(YOLOLayer, self).__init__()
|
||||
self.anchors = torch.Tensor(anchors)
|
||||
self.index = yolo_index # index of this layer in layers
|
||||
self.layers = layers # model output layer indices
|
||||
self.stride = stride # layer stride
|
||||
self.nl = len(layers) # number of output layers (3)
|
||||
self.na = len(anchors) # number of anchors (3)
|
||||
self.nc = nc # number of classes (80)
|
||||
self.no = nc + 5 # number of outputs (85)
|
||||
self.nx = 0 # initialize number of x gridpoints
|
||||
self.ny = 0 # initialize number of y gridpoints
|
||||
self.nx, self.ny = 0, 0 # initialize number of x, y gridpoints
|
||||
self.anchor_vec = self.anchors / self.stride
|
||||
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2)
|
||||
|
||||
if ONNX_EXPORT:
|
||||
stride = [32, 16, 8][yolo_index] # stride of this layer
|
||||
nx = img_size[1] // stride # number x grid points
|
||||
ny = img_size[0] // stride # number y grid points
|
||||
create_grids(self, img_size, (nx, ny))
|
||||
self.create_grids((img_size[1] // stride, img_size[0] // stride)) # number x, y grid points
|
||||
|
||||
def create_grids(self, ng=(13, 13), device='cpu'):
|
||||
self.nx, self.ny = ng # x and y grid size
|
||||
self.ng = torch.Tensor(ng).to(device)
|
||||
|
||||
# build xy offsets
|
||||
yv, xv = torch.meshgrid([torch.arange(self.ny), torch.arange(self.nx)])
|
||||
self.grid_xy = torch.stack((xv, yv), 2).to(device).view((1, 1, self.ny, self.nx, 2))
|
||||
|
||||
if self.anchor_vec.device != device:
|
||||
self.anchor_vec = self.anchor_vec.to(device)
|
||||
self.anchor_wh = self.anchor_wh.to(device)
|
||||
|
||||
def forward(self, p, img_size, out):
|
||||
ASFF = False # https://arxiv.org/abs/1911.09516
|
||||
|
@ -135,7 +148,7 @@ class YOLOLayer(nn.Module):
|
|||
p = out[self.layers[i]]
|
||||
bs, _, ny, nx = p.shape # bs, 255, 13, 13
|
||||
if (self.nx, self.ny) != (nx, ny):
|
||||
create_grids(self, img_size, (nx, ny), p.device, p.dtype)
|
||||
self.create_grids((nx, ny), p.device)
|
||||
|
||||
# outputs and weights
|
||||
# w = F.softmax(p[:, -n:], 1) # normalized weights
|
||||
|
@ -154,7 +167,7 @@ class YOLOLayer(nn.Module):
|
|||
else:
|
||||
bs, _, ny, nx = p.shape # bs, 255, 13, 13
|
||||
if (self.nx, self.ny) != (nx, ny):
|
||||
create_grids(self, img_size, (nx, ny), p.device, p.dtype)
|
||||
self.create_grids((nx, ny), p.device)
|
||||
|
||||
# p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh)
|
||||
p = p.view(bs, self.na, self.no, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||
|
@ -273,23 +286,6 @@ def get_yolo_layers(model):
|
|||
return [i for i, x in enumerate(model.module_defs) if x['type'] == 'yolo'] # [82, 94, 106] for yolov3
|
||||
|
||||
|
||||
def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float32):
|
||||
nx, ny = ng # x and y grid size
|
||||
self.img_size = max(img_size)
|
||||
self.stride = self.img_size / max(ng)
|
||||
|
||||
# build xy offsets
|
||||
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
||||
self.grid_xy = torch.stack((xv, yv), 2).to(device).type(type).view((1, 1, ny, nx, 2))
|
||||
|
||||
# build wh gains
|
||||
self.anchor_vec = self.anchors.to(device) / self.stride
|
||||
self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).type(type)
|
||||
self.ng = torch.Tensor(ng).to(device)
|
||||
self.nx = nx
|
||||
self.ny = ny
|
||||
|
||||
|
||||
def load_darknet_weights(self, weights, cutoff=-1):
|
||||
# Parses and loads the weights stored in 'weights'
|
||||
|
||||
|
|
Loading…
Reference in New Issue