updates
This commit is contained in:
		
							parent
							
								
									1cd907c59b
								
							
						
					
					
						commit
						f908f845ae
					
				
							
								
								
									
										13
									
								
								models.py
								
								
								
								
							
							
						
						
									
										13
									
								
								models.py
								
								
								
								
							|  | @ -45,7 +45,7 @@ def create_modules(module_defs): | |||
| 
 | ||||
|         elif module_def['type'] == 'upsample': | ||||
|             # upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest')  # WARNING: deprecated | ||||
|             upsample = Upsample(scale_factor=int(module_def['stride']), mode='nearest') | ||||
|             upsample = Upsample(scale_factor=int(module_def['stride'])) | ||||
|             modules.add_module('upsample_%d' % i, upsample) | ||||
| 
 | ||||
|         elif module_def['type'] == 'route': | ||||
|  | @ -131,6 +131,7 @@ class YOLOLayer(nn.Module): | |||
|         self.loss_means = torch.ones(6) | ||||
|         self.yolo_layer = anchor_idxs[0] / nA  # 2, 1, 0 | ||||
|         self.stride = stride | ||||
|         self.nG = nG | ||||
| 
 | ||||
|         if ONNX_EXPORT:  # use fully populated and reshaped tensors | ||||
|             self.anchor_w = self.anchor_w.repeat((1, 1, nG, nG)).view(1, -1, 1) | ||||
|  | @ -142,8 +143,8 @@ class YOLOLayer(nn.Module): | |||
| 
 | ||||
|     def forward(self, p, targets=None, batch_report=False, var=None): | ||||
|         FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor | ||||
|         bs = p.shape[0]  # batch size | ||||
|         nG = p.shape[2]  # number of grid points | ||||
|         bs = 1 if ONNX_EXPORT else p.shape[0]  # batch size | ||||
|         nG = self.nG  # number of grid points | ||||
| 
 | ||||
|         if p.is_cuda and not self.weights.is_cuda: | ||||
|             self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() | ||||
|  | @ -285,6 +286,9 @@ class Darknet(nn.Module): | |||
|                 x = module(x) | ||||
|             elif module_def['type'] == 'route': | ||||
|                 layer_i = [int(x) for x in module_def['layers'].split(',')] | ||||
|                 if len(layer_i) == 1: | ||||
|                     x = layer_outputs[layer_i[0]] | ||||
|                 else: | ||||
|                     x = torch.cat([layer_outputs[i] for i in layer_i], 1) | ||||
|             elif module_def['type'] == 'shortcut': | ||||
|                 layer_i = int(module_def['from']) | ||||
|  | @ -328,7 +332,8 @@ class Darknet(nn.Module): | |||
| 
 | ||||
|         if ONNX_EXPORT: | ||||
|             # Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet) | ||||
|             output = output[0]  # first layer reshaped to 85 x 507 | ||||
|             output = output[1]  # first layer reshaped to 85 x 507 | ||||
|             # output = torch.cat(output, 1) | ||||
|             return output[5:85].t(), output[:4].t()  # ONNX scores, boxes | ||||
| 
 | ||||
|         return sum(output) if is_training else torch.cat(output, 1) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue