ONNX compatibility updates
This commit is contained in:
		
							parent
							
								
									febc55d96a
								
							
						
					
					
						commit
						647e1c6f52
					
				
							
								
								
									
										54
									
								
								models.py
								
								
								
								
							
							
						
						
									
										54
									
								
								models.py
								
								
								
								
							|  | @ -127,13 +127,12 @@ class YOLOLayer(nn.Module): | ||||||
| 
 | 
 | ||||||
|         self.loss_means = torch.ones(6) |         self.loss_means = torch.ones(6) | ||||||
|         self.tx, self.ty, self.tw, self.th = [], [], [], [] |         self.tx, self.ty, self.tw, self.th = [], [], [], [] | ||||||
|  |         self.yolo_layer = anchor_idxs[0] / nA  # 2, 1, 0 | ||||||
| 
 | 
 | ||||||
|     def forward(self, p, targets=None, batch_report=False, var=None): |     def forward(self, p, targets=None, batch_report=False, var=None): | ||||||
|         FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor |         FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor | ||||||
| 
 |  | ||||||
|         bs = p.shape[0]  # batch size |         bs = p.shape[0]  # batch size | ||||||
|         nG = p.shape[2]  # number of grid points |         nG = p.shape[2]  # number of grid points | ||||||
|         stride = self.img_dim / nG |  | ||||||
| 
 | 
 | ||||||
|         if p.is_cuda and not self.grid_x.is_cuda: |         if p.is_cuda and not self.grid_x.is_cuda: | ||||||
|             self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() |             self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() | ||||||
|  | @ -143,30 +142,30 @@ class YOLOLayer(nn.Module): | ||||||
|         # p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80)  # (bs, anchors, grid, grid, classes + xywh) |         # p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80)  # (bs, anchors, grid, grid, classes + xywh) | ||||||
|         p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous()  # prediction |         p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous()  # prediction | ||||||
| 
 | 
 | ||||||
|         # Get outputs |  | ||||||
|         x = torch.sigmoid(p[..., 0])  # Center x |  | ||||||
|         y = torch.sigmoid(p[..., 1])  # Center y |  | ||||||
|         p_conf = p[..., 4]  # Conf |  | ||||||
|         p_cls = p[..., 5:]  # Class |  | ||||||
| 
 |  | ||||||
|         # Width and height (yolo method) |  | ||||||
|         w = p[..., 2]  # Width |  | ||||||
|         h = p[..., 3]  # Height |  | ||||||
|         width = torch.exp(w.data) * self.anchor_w |  | ||||||
|         height = torch.exp(h.data) * self.anchor_h |  | ||||||
| 
 |  | ||||||
|         # Width and height (power method) |  | ||||||
|         # w = torch.sigmoid(p[..., 2])  # Width |  | ||||||
|         # h = torch.sigmoid(p[..., 3])  # Height |  | ||||||
|         # width = ((w.data * 2) ** 2) * self.anchor_w |  | ||||||
|         # height = ((h.data * 2) ** 2) * self.anchor_h |  | ||||||
| 
 |  | ||||||
|         # Training |         # Training | ||||||
|         if targets is not None: |         if targets is not None: | ||||||
|             MSELoss = nn.MSELoss() |             MSELoss = nn.MSELoss() | ||||||
|             BCEWithLogitsLoss = nn.BCEWithLogitsLoss() |             BCEWithLogitsLoss = nn.BCEWithLogitsLoss() | ||||||
|             CrossEntropyLoss = nn.CrossEntropyLoss() |             CrossEntropyLoss = nn.CrossEntropyLoss() | ||||||
| 
 | 
 | ||||||
|  |             # Get outputs | ||||||
|  |             x = torch.sigmoid(p[..., 0])  # Center x | ||||||
|  |             y = torch.sigmoid(p[..., 1])  # Center y | ||||||
|  |             p_conf = p[..., 4]  # Conf | ||||||
|  |             p_cls = p[..., 5:]  # Class | ||||||
|  | 
 | ||||||
|  |             # Width and height (yolo method) | ||||||
|  |             w = p[..., 2]  # Width | ||||||
|  |             h = p[..., 3]  # Height | ||||||
|  |             width = torch.exp(w.data) * self.anchor_w | ||||||
|  |             height = torch.exp(h.data) * self.anchor_h | ||||||
|  | 
 | ||||||
|  |             # Width and height (power method) | ||||||
|  |             # w = torch.sigmoid(p[..., 2])  # Width | ||||||
|  |             # h = torch.sigmoid(p[..., 3])  # Height | ||||||
|  |             # width = ((w.data * 2) ** 2) * self.anchor_w | ||||||
|  |             # height = ((h.data * 2) ** 2) * self.anchor_h | ||||||
|  | 
 | ||||||
|             p_boxes = None |             p_boxes = None | ||||||
|             if batch_report: |             if batch_report: | ||||||
|                 # Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13) |                 # Predictd boxes: add offset and scale with anchors (in grid space, i.e. 0-13) | ||||||
|  | @ -239,14 +238,15 @@ class YOLOLayer(nn.Module): | ||||||
|                    nT, TP, FP, FPe, FN, TC |                    nT, TP, FP, FPe, FN, TC | ||||||
| 
 | 
 | ||||||
|         else: |         else: | ||||||
|             # If not in training phase return predictions |             stride = self.img_dim / nG | ||||||
|             p_boxes = torch.stack((x + self.grid_x, y + self.grid_y, width, height), 4)  # xywh |             p[..., 0] = torch.sigmoid(p[..., 0]) + self.grid_x  # x | ||||||
|  |             p[..., 1] = torch.sigmoid(p[..., 1]) + self.grid_y  # y | ||||||
|  |             p[..., 2] = torch.exp(p[..., 2]) * self.anchor_w  # width | ||||||
|  |             p[..., 3] = torch.exp(p[..., 3]) * self.anchor_h  # height | ||||||
|  |             p[..., 4] = torch.sigmoid(p[..., 4])  # p_conf | ||||||
|  |             p[..., :4] *= stride | ||||||
| 
 | 
 | ||||||
|             # output.shape = [1, 3, 13, 13, 85] |             return p.view(bs, self.nA * nG * nG, 5 + self.nC) | ||||||
|             output = torch.cat((p_boxes * stride, torch.sigmoid(p_conf).unsqueeze(4), p_cls), 4) |  | ||||||
| 
 |  | ||||||
|             # returns shape = [1, 507, 85] |  | ||||||
|             return output.data.view(bs, -1, 5 + self.nC) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Darknet(nn.Module): | class Darknet(nn.Module): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue