This commit is contained in:
Glenn Jocher 2020-03-09 18:55:17 -07:00
parent d8370d13ea
commit 17a06dcf83
1 changed files with 12 additions and 8 deletions

View File

@ -81,18 +81,20 @@ def create_modules(module_defs, img_size, arc):
elif mdef['type'] == 'yolo': elif mdef['type'] == 'yolo':
yolo_index += 1 yolo_index += 1
l = mdef['from'] if 'from' in mdef else []
modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']], # anchor list modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']], # anchor list
nc=mdef['classes'], # number of classes nc=mdef['classes'], # number of classes
img_size=img_size, # (416, 416) img_size=img_size, # (416, 416)
yolo_index=yolo_index, # 0, 1 or 2 yolo_index=yolo_index, # 0, 1, 2...
arc=arc) # yolo architecture layers=l) # output layers
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3) # Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
try: try:
bo = -4.5 #  obj bias bo = -4.5 #  obj bias
bc = math.log(1 / (modules.nc - 0.99)) # cls bias: class probability is sigmoid(p) = 1/nc bc = math.log(1 / (modules.nc - 0.99)) # cls bias: class probability is sigmoid(p) = 1/nc
bias = module_list[-1][0].bias.view(modules.na, -1) # 255 to 3x85 j = l[yolo_index] if 'from' in mdef else -1
bias = module_list[j][0].bias.view(modules.na, -1) # 255 to 3x85
bias[:, 4] += bo - bias[:, 4].mean() # obj bias[:, 4] += bo - bias[:, 4].mean() # obj
bias[:, 5:] += bc - bias[:, 5:].mean() # cls, view with utils.print_model_biases(model) bias[:, 5:] += bc - bias[:, 5:].mean() # cls, view with utils.print_model_biases(model)
except: except:
@ -168,15 +170,17 @@ class Mish(nn.Module): # https://github.com/digantamisra98/Mish
class YOLOLayer(nn.Module): class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, yolo_index, arc): def __init__(self, anchors, nc, img_size, yolo_index, layers):
super(YOLOLayer, self).__init__() super(YOLOLayer, self).__init__()
self.anchors = torch.Tensor(anchors) self.anchors = torch.Tensor(anchors)
self.index = yolo_index # index of this layer in layers
self.layers = layers # model output layer indices
self.nl = len(layers) # number of output layers (3)
self.na = len(anchors) # number of anchors (3) self.na = len(anchors) # number of anchors (3)
self.nc = nc # number of classes (80) self.nc = nc # number of classes (80)
self.no = nc + 5 # number of outputs self.no = nc + 5 # number of outputs (85)
self.nx = 0 # initialize number of x gridpoints self.nx = 0 # initialize number of x gridpoints
self.ny = 0 # initialize number of y gridpoints self.ny = 0 # initialize number of y gridpoints
self.arc = arc
if ONNX_EXPORT: if ONNX_EXPORT:
stride = [32, 16, 8][yolo_index] # stride of this layer stride = [32, 16, 8][yolo_index] # stride of this layer
@ -184,7 +188,7 @@ class YOLOLayer(nn.Module):
ny = img_size[0] // stride # number y grid points ny = img_size[0] // stride # number y grid points
create_grids(self, img_size, (nx, ny)) create_grids(self, img_size, (nx, ny))
def forward(self, p, img_size): def forward(self, p, img_size, out):
if ONNX_EXPORT: if ONNX_EXPORT:
bs = 1 # batch size bs = 1 # batch size
else: else:
@ -268,7 +272,7 @@ class Darknet(nn.Module):
x = torch.cat([out[i] for i in layers], 1) x = torch.cat([out[i] for i in layers], 1)
# print(''), [print(out[i].shape) for i in layers], print(x.shape) # print(''), [print(out[i].shape) for i in layers], print(x.shape)
elif mtype == 'yolo': elif mtype == 'yolo':
yolo_out.append(module(x, img_size)) yolo_out.append(module(x, img_size, out))
out.append(x if i in self.routs else []) out.append(x if i in self.routs else [])
if verbose: if verbose:
print('%g/%g %s -' % (i, len(self.module_list), mtype), list(x.shape), str) print('%g/%g %s -' % (i, len(self.module_list), mtype), list(x.shape), str)