updates
This commit is contained in:
		
							parent
							
								
									5fcdcefec3
								
							
						
					
					
						commit
						dc9f2ef6ba
					
				|  | @ -150,10 +150,13 @@ class YOLOLayer(nn.Module): | |||
|             p_conf = p[..., 4]  # Conf | ||||
|             p_cls = p[..., 5:]  # Class | ||||
| 
 | ||||
|             if p.is_cuda: | ||||
|                 txy, twh, mask, tcls = build_targets(targets, self.anchor_vec.cuda(), self.nA, self.nC, nG) | ||||
|             else: | ||||
|                 txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG) | ||||
| 
 | ||||
|             tcls = tcls[mask] | ||||
|             if xy.is_cuda: | ||||
|             if p.is_cuda: | ||||
|                 txy, twh, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda() | ||||
| 
 | ||||
|             # Compute losses | ||||
|  |  | |||
|  | @ -242,8 +242,14 @@ def build_targets(target, anchor_vec, nA, nC, nG): | |||
|     tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0) | ||||
|     tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0)  # nC = number of classes | ||||
| 
 | ||||
|     if anchor_vec.is_cuda(): | ||||
|         txy = txy.cuda() | ||||
|         twh = twh.cuda() | ||||
|         tconf = tconf.cuda() | ||||
|         tcls = tcls.cuda() | ||||
| 
 | ||||
|     for b in range(nB): | ||||
|         t = target[b].cpu() | ||||
|         t = target[b] | ||||
|         nTb = len(t)  # number of targets | ||||
|         if nTb == 0: | ||||
|             continue | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue