updates
This commit is contained in:
		
							parent
							
								
									f07dd72a09
								
							
						
					
					
						commit
						a116dd36f7
					
				|  | @ -132,7 +132,7 @@ class YOLOLayer(nn.Module): | |||
| 
 | ||||
|                 if p.is_cuda: | ||||
|                     self.grid_xy = self.grid_xy.cuda() | ||||
|                     self.anchor_vector = self.anchor_vector.cuda() | ||||
|                     self.anchor_vec = self.anchor_vec.cuda() | ||||
| 
 | ||||
|         # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 80)  # (bs, anchors, grid, grid, classes + xywh) | ||||
|         p = p.view(bs, self.nA, self.nC + 5, nG, nG).permute(0, 1, 3, 4, 2).contiguous()  # prediction | ||||
|  | @ -161,7 +161,7 @@ class YOLOLayer(nn.Module): | |||
|             # width = ((w.data * 2) ** 2) * self.anchor_w | ||||
|             # height = ((h.data * 2) ** 2) * self.anchor_h | ||||
| 
 | ||||
|             tx, ty, tw, th, mask, tcls = build_targets(targets, self.anchor_vector, self.nA, self.nC, nG) | ||||
|             tx, ty, tw, th, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG) | ||||
| 
 | ||||
|             tcls = tcls[mask] | ||||
|             if x.is_cuda: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue