updates
This commit is contained in:
		
							parent
							
								
									6130b70fe7
								
							
						
					
					
						commit
						204594f299
					
				
							
								
								
									
										18
									
								
								models.py
								
								
								
								
							
							
						
						
									
										18
									
								
								models.py
								
								
								
								
							|  | @ -169,7 +169,6 @@ class Mish(nn.Module):  # https://github.com/digantamisra98/Mish | ||||||
| class YOLOLayer(nn.Module): | class YOLOLayer(nn.Module): | ||||||
|     def __init__(self, anchors, nc, img_size, yolo_index, arc): |     def __init__(self, anchors, nc, img_size, yolo_index, arc): | ||||||
|         super(YOLOLayer, self).__init__() |         super(YOLOLayer, self).__init__() | ||||||
| 
 |  | ||||||
|         self.anchors = torch.Tensor(anchors) |         self.anchors = torch.Tensor(anchors) | ||||||
|         self.na = len(anchors)  # number of anchors (3) |         self.na = len(anchors)  # number of anchors (3) | ||||||
|         self.nc = nc  # number of classes (80) |         self.nc = nc  # number of classes (80) | ||||||
|  | @ -213,27 +212,12 @@ class YOLOLayer(nn.Module): | ||||||
|             return p_cls, xy * ng, wh |             return p_cls, xy * ng, wh | ||||||
| 
 | 
 | ||||||
|         else:  # inference |         else:  # inference | ||||||
|             # s = 1.5  # scale_xy  (pxy = pxy * s - (s - 1) / 2) |  | ||||||
|             io = p.clone()  # inference output |             io = p.clone()  # inference output | ||||||
|             io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid_xy  # xy |             io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid_xy  # xy | ||||||
|             io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh  # wh yolo method |             io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh  # wh yolo method | ||||||
|             # io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh  # wh power method |  | ||||||
|             io[..., :4] *= self.stride |             io[..., :4] *= self.stride | ||||||
| 
 |  | ||||||
|             if 'default' in self.arc:  # seperate obj and cls |  | ||||||
|             torch.sigmoid_(io[..., 4:]) |             torch.sigmoid_(io[..., 4:]) | ||||||
|             elif 'BCE' in self.arc:  # unified BCE (80 classes) |             return io.view(bs, -1, self.no), p  # view [1, 3, 13, 13, 85] as [1, 507, 85] | ||||||
|                 torch.sigmoid_(io[..., 5:]) |  | ||||||
|                 io[..., 4] = 1 |  | ||||||
|             elif 'CE' in self.arc:  # unified CE (1 background + 80 classes) |  | ||||||
|                 io[..., 4:] = F.softmax(io[..., 4:], dim=4) |  | ||||||
|                 io[..., 4] = 1 |  | ||||||
| 
 |  | ||||||
|             if self.nc == 1: |  | ||||||
|                 io[..., 5] = 1  # single-class model https://github.com/ultralytics/yolov3/issues/235 |  | ||||||
| 
 |  | ||||||
|             # reshape from [1, 3, 13, 13, 85] to [1, 507, 85] |  | ||||||
|             return io.view(bs, -1, self.no), p |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Darknet(nn.Module): | class Darknet(nn.Module): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue