updates
This commit is contained in:
parent
ab2ea5a2f9
commit
3cd76b2185
|
@ -226,7 +226,6 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
|||
th = torch.zeros(nB, nA, nG, nG)
|
||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||
TC = torch.ShortTensor(nB, max(nT)).fill_(-1) # target category
|
||||
|
||||
for b in range(nB):
|
||||
nTb = nT[b] # number of targets
|
||||
|
@ -235,7 +234,8 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
|||
t = target[b]
|
||||
|
||||
# Convert to position relative to box
|
||||
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||
gx, gy, gw, gh = t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||
|
||||
# Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors)
|
||||
gi = torch.clamp(gx.long(), min=0, max=nG - 1)
|
||||
gj = torch.clamp(gy.long(), min=0, max=nG - 1)
|
||||
|
@ -270,7 +270,6 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
|||
else:
|
||||
if iou_best < 0.10:
|
||||
continue
|
||||
i = 0
|
||||
|
||||
tc, gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||
|
||||
|
|
Loading…
Reference in New Issue