updates
This commit is contained in:
parent
ab2ea5a2f9
commit
3cd76b2185
|
@ -226,7 +226,6 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||||
th = torch.zeros(nB, nA, nG, nG)
|
th = torch.zeros(nB, nA, nG, nG)
|
||||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||||
TC = torch.ShortTensor(nB, max(nT)).fill_(-1) # target category
|
|
||||||
|
|
||||||
for b in range(nB):
|
for b in range(nB):
|
||||||
nTb = nT[b] # number of targets
|
nTb = nT[b] # number of targets
|
||||||
|
@ -235,7 +234,8 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||||
t = target[b]
|
t = target[b]
|
||||||
|
|
||||||
# Convert to position relative to box
|
# Convert to position relative to box
|
||||||
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
gx, gy, gw, gh = t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||||
|
|
||||||
# Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors)
|
# Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors)
|
||||||
gi = torch.clamp(gx.long(), min=0, max=nG - 1)
|
gi = torch.clamp(gx.long(), min=0, max=nG - 1)
|
||||||
gj = torch.clamp(gy.long(), min=0, max=nG - 1)
|
gj = torch.clamp(gy.long(), min=0, max=nG - 1)
|
||||||
|
@ -270,7 +270,6 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||||
else:
|
else:
|
||||||
if iou_best < 0.10:
|
if iou_best < 0.10:
|
||||||
continue
|
continue
|
||||||
i = 0
|
|
||||||
|
|
||||||
tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue