updates
This commit is contained in:
		
							parent
							
								
									4a4668224b
								
							
						
					
					
						commit
						14e4519620
					
				|  | @ -55,11 +55,13 @@ def detect( | ||||||
|         t = time.time() |         t = time.time() | ||||||
|         save_path = str(Path(output) / Path(path).name) |         save_path = str(Path(output) / Path(path).name) | ||||||
| 
 | 
 | ||||||
|  |         if ONNX_EXPORT: | ||||||
|  |             img = torch.zeros((1, 3, 416, 416)) | ||||||
|  |             torch.onnx.export(model, img, 'weights/export.onnx', verbose=True) | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|         # Get detections |         # Get detections | ||||||
|         img = torch.from_numpy(img).unsqueeze(0).to(device) |         img = torch.from_numpy(img).unsqueeze(0).to(device) | ||||||
|         if ONNX_EXPORT: |  | ||||||
|             torch.onnx.export(model, img, 'weights/model.onnx', verbose=True) |  | ||||||
|             return |  | ||||||
|         pred, _ = model(img) |         pred, _ = model(img) | ||||||
|         detections = non_max_suppression(pred, conf_thres, nms_thres)[0] |         detections = non_max_suppression(pred, conf_thres, nms_thres)[0] | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										14
									
								
								models.py
								
								
								
								
							
							
						
						
									
										14
									
								
								models.py
								
								
								
								
							|  | @ -5,7 +5,7 @@ import torch.nn.functional as F | ||||||
| from utils.parse_config import * | from utils.parse_config import * | ||||||
| from utils.utils import * | from utils.utils import * | ||||||
| 
 | 
 | ||||||
| ONNX_EXPORT = False | ONNX_EXPORT = True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_modules(module_defs): | def create_modules(module_defs): | ||||||
|  | @ -234,18 +234,20 @@ def get_yolo_layers(model): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_grids(self, img_size, ng, device='cpu'): | def create_grids(self, img_size, ng, device='cpu'): | ||||||
|  |     nx, ny = ng, ng  # x and y grid size | ||||||
|     self.img_size = img_size |     self.img_size = img_size | ||||||
|     self.stride = img_size / ng |     self.stride = img_size / nx | ||||||
| 
 | 
 | ||||||
|     # build xy offsets |     # build xy offsets | ||||||
|     grid_x = torch.arange(ng).repeat((ng, 1)).view((1, 1, ng, ng)).float() |     yv, xv = torch.meshgrid([torch.arange(nx), torch.arange(ny)]) | ||||||
|     grid_y = grid_x.permute(0, 1, 3, 2) |     self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, nx, ny, 2)) | ||||||
|     self.grid_xy = torch.stack((grid_x, grid_y), 4).to(device) |  | ||||||
| 
 | 
 | ||||||
|     # build wh gains |     # build wh gains | ||||||
|     self.anchor_vec = self.anchors.to(device) / self.stride |     self.anchor_vec = self.anchors.to(device) / self.stride | ||||||
|     self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device) |     self.anchor_wh = self.anchor_vec.view(1, self.na, 1, 1, 2).to(device) | ||||||
|     self.ng = torch.FloatTensor([ng]).to(device) |     self.ng = torch.Tensor([ng]).to(device) | ||||||
|  |     self.nx = nx | ||||||
|  |     self.ny = ny | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def load_darknet_weights(self, weights, cutoff=-1): | def load_darknet_weights(self, weights, cutoff=-1): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue