updates
This commit is contained in:
parent
5fcdcefec3
commit
dc9f2ef6ba
|
@ -150,10 +150,13 @@ class YOLOLayer(nn.Module):
|
||||||
p_conf = p[..., 4] # Conf
|
p_conf = p[..., 4] # Conf
|
||||||
p_cls = p[..., 5:] # Class
|
p_cls = p[..., 5:] # Class
|
||||||
|
|
||||||
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
|
if p.is_cuda:
|
||||||
|
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec.cuda(), self.nA, self.nC, nG)
|
||||||
|
else:
|
||||||
|
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
|
||||||
|
|
||||||
tcls = tcls[mask]
|
tcls = tcls[mask]
|
||||||
if xy.is_cuda:
|
if p.is_cuda:
|
||||||
txy, twh, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda()
|
txy, twh, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda()
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
|
|
|
@ -242,8 +242,14 @@ def build_targets(target, anchor_vec, nA, nC, nG):
|
||||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||||
|
|
||||||
|
if anchor_vec.is_cuda():
|
||||||
|
txy = txy.cuda()
|
||||||
|
twh = twh.cuda()
|
||||||
|
tconf = tconf.cuda()
|
||||||
|
tcls = tcls.cuda()
|
||||||
|
|
||||||
for b in range(nB):
|
for b in range(nB):
|
||||||
t = target[b].cpu()
|
t = target[b]
|
||||||
nTb = len(t) # number of targets
|
nTb = len(t) # number of targets
|
||||||
if nTb == 0:
|
if nTb == 0:
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in New Issue