updates
This commit is contained in:
		
							parent
							
								
									058bb7f38d
								
							
						
					
					
						commit
						e2a8f5bdce
					
				
							
								
								
									
										19
									
								
								models.py
								
								
								
								
							
							
						
						
									
										19
									
								
								models.py
								
								
								
								
							|  | @ -37,18 +37,18 @@ def create_modules(module_defs): | |||
|             modules.add_module('upsample_%d' % i, upsample) | ||||
| 
 | ||||
|         elif module_def['type'] == 'route': | ||||
|             layers = [int(x) for x in module_def["layers"].split(',')] | ||||
|             layers = [int(x) for x in module_def['layers'].split(',')] | ||||
|             filters = sum([output_filters[layer_i] for layer_i in layers]) | ||||
|             modules.add_module('route_%d' % i, EmptyLayer()) | ||||
| 
 | ||||
|         elif module_def['type'] == 'shortcut': | ||||
|             filters = output_filters[int(module_def['from'])] | ||||
|             modules.add_module("shortcut_%d" % i, EmptyLayer()) | ||||
|             modules.add_module('shortcut_%d' % i, EmptyLayer()) | ||||
| 
 | ||||
|         elif module_def["type"] == "yolo": | ||||
|             anchor_idxs = [int(x) for x in module_def["mask"].split(",")] | ||||
|         elif module_def['type'] == 'yolo': | ||||
|             anchor_idxs = [int(x) for x in module_def['mask'].split(',')] | ||||
|             # Extract anchors | ||||
|             anchors = [float(x) for x in module_def["anchors"].split(",")] | ||||
|             anchors = [float(x) for x in module_def['anchors'].split(',')] | ||||
|             anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] | ||||
|             anchors = [anchors[i] for i in anchor_idxs] | ||||
|             num_classes = int(module_def['classes']) | ||||
|  | @ -72,7 +72,6 @@ class EmptyLayer(nn.Module): | |||
| 
 | ||||
| 
 | ||||
| class YOLOLayer(nn.Module): | ||||
|     # YOLO Layer 0 | ||||
| 
 | ||||
|     def __init__(self, anchors, nC, img_dim, anchor_idxs): | ||||
|         super(YOLOLayer, self).__init__() | ||||
|  | @ -104,15 +103,15 @@ class YOLOLayer(nn.Module): | |||
|     def forward(self, p, targets=None, requestPrecision=False, epoch=None): | ||||
|         FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor | ||||
| 
 | ||||
|         bs = p.shape[0] | ||||
|         nG = p.shape[2] | ||||
|         bs = p.shape[0]  # batch size | ||||
|         nG = p.shape[2]  # number of grid points | ||||
|         stride = self.img_dim / nG | ||||
| 
 | ||||
|         if p.is_cuda and not self.grid_x.is_cuda: | ||||
|             self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() | ||||
|             self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda() | ||||
| 
 | ||||
|         # x.view(4, 650, 19, 19) -- > (4, 10, 19, 19, 65)  # (bs, anchors, grid, grid, classes + xywh) | ||||
|         # p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80)  # (bs, anchors, grid, grid, classes + xywh) | ||||
|         p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous()  # prediction | ||||
| 
 | ||||
|         # Get outputs | ||||
|  | @ -255,7 +254,7 @@ def load_weights(self, weights_path): | |||
|     """Parses and loads the weights stored in 'weights_path'""" | ||||
| 
 | ||||
|     # Open the weights file | ||||
|     fp = open(weights_path, "rb") | ||||
|     fp = open(weights_path, 'rb') | ||||
|     header = np.fromfile(fp, dtype=np.int32, count=5)  # First five are header values | ||||
| 
 | ||||
|     # Needed to write header when saving weights | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue