updates
This commit is contained in:
		
							parent
							
								
									d8370d13ea
								
							
						
					
					
						commit
						17a06dcf83
					
				
							
								
								
									
										20
									
								
								models.py
								
								
								
								
							
							
						
						
									
										20
									
								
								models.py
								
								
								
								
							|  | @ -81,18 +81,20 @@ def create_modules(module_defs, img_size, arc): | |||
| 
 | ||||
|         elif mdef['type'] == 'yolo': | ||||
|             yolo_index += 1 | ||||
|             l = mdef['from'] if 'from' in mdef else [] | ||||
|             modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']],  # anchor list | ||||
|                                 nc=mdef['classes'],  # number of classes | ||||
|                                 img_size=img_size,  # (416, 416) | ||||
|                                 yolo_index=yolo_index,  # 0, 1 or 2 | ||||
|                                 arc=arc)  # yolo architecture | ||||
|                                 yolo_index=yolo_index,  # 0, 1, 2... | ||||
|                                 layers=l)  # output layers | ||||
| 
 | ||||
|             # Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3) | ||||
|             try: | ||||
|                 bo = -4.5  #  obj bias | ||||
|                 bc = math.log(1 / (modules.nc - 0.99))  # cls bias: class probability is sigmoid(p) = 1/nc | ||||
| 
 | ||||
|                 bias = module_list[-1][0].bias.view(modules.na, -1)  # 255 to 3x85 | ||||
|                 j = l[yolo_index] if 'from' in mdef else -1 | ||||
|                 bias = module_list[j][0].bias.view(modules.na, -1)  # 255 to 3x85 | ||||
|                 bias[:, 4] += bo - bias[:, 4].mean()  # obj | ||||
|                 bias[:, 5:] += bc - bias[:, 5:].mean()  # cls, view with utils.print_model_biases(model) | ||||
|             except: | ||||
|  | @ -168,15 +170,17 @@ class Mish(nn.Module):  # https://github.com/digantamisra98/Mish | |||
| 
 | ||||
| 
 | ||||
| class YOLOLayer(nn.Module): | ||||
|     def __init__(self, anchors, nc, img_size, yolo_index, arc): | ||||
|     def __init__(self, anchors, nc, img_size, yolo_index, layers): | ||||
|         super(YOLOLayer, self).__init__() | ||||
|         self.anchors = torch.Tensor(anchors) | ||||
|         self.index = yolo_index  # index of this layer in layers | ||||
|         self.layers = layers  # model output layer indices | ||||
|         self.nl = len(layers)  # number of output layers (3) | ||||
|         self.na = len(anchors)  # number of anchors (3) | ||||
|         self.nc = nc  # number of classes (80) | ||||
|         self.no = nc + 5  # number of outputs | ||||
|         self.no = nc + 5  # number of outputs (85) | ||||
|         self.nx = 0  # initialize number of x gridpoints | ||||
|         self.ny = 0  # initialize number of y gridpoints | ||||
|         self.arc = arc | ||||
| 
 | ||||
|         if ONNX_EXPORT: | ||||
|             stride = [32, 16, 8][yolo_index]  # stride of this layer | ||||
|  | @ -184,7 +188,7 @@ class YOLOLayer(nn.Module): | |||
|             ny = img_size[0] // stride  # number y grid points | ||||
|             create_grids(self, img_size, (nx, ny)) | ||||
| 
 | ||||
|     def forward(self, p, img_size): | ||||
|     def forward(self, p, img_size, out): | ||||
|         if ONNX_EXPORT: | ||||
|             bs = 1  # batch size | ||||
|         else: | ||||
|  | @ -268,7 +272,7 @@ class Darknet(nn.Module): | |||
|                         x = torch.cat([out[i] for i in layers], 1) | ||||
|                     # print(''), [print(out[i].shape) for i in layers], print(x.shape) | ||||
|             elif mtype == 'yolo': | ||||
|                 yolo_out.append(module(x, img_size)) | ||||
|                 yolo_out.append(module(x, img_size, out)) | ||||
|             out.append(x if i in self.routs else []) | ||||
|             if verbose: | ||||
|                 print('%g/%g %s -' % (i, len(self.module_list), mtype), list(x.shape), str) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue