This commit is contained in:
Glenn Jocher 2019-02-10 23:27:31 +01:00
parent ab2ea5a2f9
commit 3cd76b2185
1 changed files with 2 additions and 3 deletions

View File

@ -226,7 +226,6 @@ def build_targets(target, anchor_wh, nA, nC, nG):
th = torch.zeros(nB, nA, nG, nG) th = torch.zeros(nB, nA, nG, nG)
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0) tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
TC = torch.ShortTensor(nB, max(nT)).fill_(-1) # target category
for b in range(nB): for b in range(nB):
nTb = nT[b] # number of targets nTb = nT[b] # number of targets
@ -235,7 +234,8 @@ def build_targets(target, anchor_wh, nA, nC, nG):
t = target[b] t = target[b]
# Convert to position relative to box # Convert to position relative to box
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG gx, gy, gw, gh = t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
# Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors) # Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors)
gi = torch.clamp(gx.long(), min=0, max=nG - 1) gi = torch.clamp(gx.long(), min=0, max=nG - 1)
gj = torch.clamp(gy.long(), min=0, max=nG - 1) gj = torch.clamp(gy.long(), min=0, max=nG - 1)
@ -270,7 +270,6 @@ def build_targets(target, anchor_wh, nA, nC, nG):
else: else:
if iou_best < 0.10: if iou_best < 0.10:
continue continue
i = 0
tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG