updates
This commit is contained in:
parent
55c6efbb39
commit
545f756090
|
@ -236,17 +236,17 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
|||
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
|
||||
"""
|
||||
nB = len(target) # number of images in batch
|
||||
nT = [len(x) for x in target]
|
||||
|
||||
txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size
|
||||
twh = torch.zeros(nB, nA, nG, nG, 2)
|
||||
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
|
||||
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
|
||||
|
||||
for b in range(nB):
|
||||
nTb = nT[b] # number of targets
|
||||
t = target[b]
|
||||
nTb = len(t) # number of targets
|
||||
if nTb == 0:
|
||||
continue
|
||||
t = target[b]
|
||||
|
||||
gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG
|
||||
|
||||
|
@ -267,14 +267,13 @@ def build_targets(target, anchor_wh, nA, nC, nG):
|
|||
iou_order = torch.argsort(-iou_best) # best to worst
|
||||
|
||||
# Unique anchor selection
|
||||
u = torch.cat((gi, gj, a), 0).view((3, -1))
|
||||
# u = torch.stack((gi, gj, a),0)
|
||||
_, first_unique = np.unique(u[:, iou_order], axis=1, return_index=True) # first unique indices
|
||||
# _, first_unique = torch.unique(u[:, iou_order], dim=1, return_inverse=True) # different than numpy?
|
||||
u = torch.stack((gi, gj, a), 0)[:, iou_order]
|
||||
# _, first_unique = np.unique(u, axis=1, return_index=True) # first unique indices
|
||||
first_unique = return_torch_unique_index(u, torch.unique(u, dim=1)) # torch alternative
|
||||
|
||||
i = iou_order[first_unique]
|
||||
# best anchor must share significant commonality (iou) with target
|
||||
i = i[iou_best[i] > 0.10] # TODO: arbitrary threshold is problematic
|
||||
i = i[iou_best[i] > 0.10] # TODO: examine arbitrary threshold
|
||||
if len(i) == 0:
|
||||
continue
|
||||
|
||||
|
@ -428,6 +427,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
return output
|
||||
|
||||
|
||||
def return_torch_unique_index(u, uv):
|
||||
n = uv.shape[1] # number of columns
|
||||
first_unique = torch.zeros(n, device=u.device).long()
|
||||
for j in range(n):
|
||||
first_unique[j] = (uv[:, j:j + 1] == u).all(0).nonzero()[0]
|
||||
|
||||
return first_unique
|
||||
|
||||
|
||||
def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
|
||||
# Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)
|
||||
|
||||
|
|
Loading…
Reference in New Issue