This commit is contained in:
Glenn Jocher 2019-03-04 16:11:37 +01:00
parent 5fcdcefec3
commit dc9f2ef6ba
2 changed files with 12 additions and 3 deletions

View File

@ -150,10 +150,13 @@ class YOLOLayer(nn.Module):
p_conf = p[..., 4] # Conf p_conf = p[..., 4] # Conf
p_cls = p[..., 5:] # Class p_cls = p[..., 5:] # Class
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG) if p.is_cuda:
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec.cuda(), self.nA, self.nC, nG)
else:
txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
tcls = tcls[mask] tcls = tcls[mask]
if xy.is_cuda: if p.is_cuda:
txy, twh, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda() txy, twh, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda()
# Compute losses # Compute losses

View File

@ -242,8 +242,14 @@ def build_targets(target, anchor_vec, nA, nC, nG):
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0) tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
if anchor_vec.is_cuda():
txy = txy.cuda()
twh = twh.cuda()
tconf = tconf.cuda()
tcls = tcls.cuda()
for b in range(nB): for b in range(nB):
t = target[b].cpu() t = target[b]
nTb = len(t) # number of targets nTb = len(t) # number of targets
if nTb == 0: if nTb == 0:
continue continue