updates
This commit is contained in:
parent
55c6efbb39
commit
545f756090
|
@ -236,17 +236,17 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||||
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
|
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
|
||||||
"""
|
"""
|
||||||
nB = len(target) # number of images in batch
|
nB = len(target) # number of images in batch
|
||||||
nT = [len(x) for x in target]
|
|
||||||
txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size
|
txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size
|
||||||
twh = torch.zeros(nB, nA, nG, nG, 2)
|
twh = torch.zeros(nB, nA, nG, nG, 2)
|
||||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||||
|
|
||||||
for b in range(nB):
|
for b in range(nB):
|
||||||
nTb = nT[b] # number of targets
|
t = target[b]
|
||||||
|
nTb = len(t) # number of targets
|
||||||
if nTb == 0:
|
if nTb == 0:
|
||||||
continue
|
continue
|
||||||
t = target[b]
|
|
||||||
|
|
||||||
gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG
|
gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG
|
||||||
|
|
||||||
|
@ -267,14 +267,13 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
||||||
iou_order = torch.argsort(-iou_best) # best to worst
|
iou_order = torch.argsort(-iou_best) # best to worst
|
||||||
|
|
||||||
# Unique anchor selection
|
# Unique anchor selection
|
||||||
u = torch.cat((gi, gj, a), 0).view((3, -1))
|
u = torch.stack((gi, gj, a), 0)[:, iou_order]
|
||||||
# u = torch.stack((gi, gj, a),0)
|
# _, first_unique = np.unique(u, axis=1, return_index=True) # first unique indices
|
||||||
_, first_unique = np.unique(u[:, iou_order], axis=1, return_index=True) # first unique indices
|
first_unique = return_torch_unique_index(u, torch.unique(u, dim=1)) # torch alternative
|
||||||
# _, first_unique = torch.unique(u[:, iou_order], dim=1, return_inverse=True) # different than numpy?
|
|
||||||
|
|
||||||
i = iou_order[first_unique]
|
i = iou_order[first_unique]
|
||||||
# best anchor must share significant commonality (iou) with target
|
# best anchor must share significant commonality (iou) with target
|
||||||
i = i[iou_best[i] > 0.10] # TODO: arbitrary threshold is problematic
|
i = i[iou_best[i] > 0.10] # TODO: examine arbitrary threshold
|
||||||
if len(i) == 0:
|
if len(i) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -428,6 +427,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def return_torch_unique_index(u, uv):
|
||||||
|
n = uv.shape[1] # number of columns
|
||||||
|
first_unique = torch.zeros(n, device=u.device).long()
|
||||||
|
for j in range(n):
|
||||||
|
first_unique[j] = (uv[:, j:j + 1] == u).all(0).nonzero()[0]
|
||||||
|
|
||||||
|
return first_unique
|
||||||
|
|
||||||
|
|
||||||
def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
|
def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
|
||||||
# Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)
|
# Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue