This commit is contained in:
Glenn Jocher 2019-02-28 15:40:30 +01:00
parent 55c6efbb39
commit 545f756090
1 changed files with 16 additions and 8 deletions

View File

@ -236,17 +236,17 @@ def build_targets(target, anchor_wh, nA, nC, nG):
returns nT, nCorrect, tx, ty, tw, th, tconf, tcls returns nT, nCorrect, tx, ty, tw, th, tconf, tcls
""" """
nB = len(target) # number of images in batch nB = len(target) # number of images in batch
nT = [len(x) for x in target]
txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size
twh = torch.zeros(nB, nA, nG, nG, 2) twh = torch.zeros(nB, nA, nG, nG, 2)
tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0) tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes
for b in range(nB): for b in range(nB):
nTb = nT[b] # number of targets t = target[b]
nTb = len(t) # number of targets
if nTb == 0: if nTb == 0:
continue continue
t = target[b]
gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG
@ -267,14 +267,13 @@ def build_targets(target, anchor_wh, nA, nC, nG):
iou_order = torch.argsort(-iou_best) # best to worst iou_order = torch.argsort(-iou_best) # best to worst
# Unique anchor selection # Unique anchor selection
u = torch.cat((gi, gj, a), 0).view((3, -1)) u = torch.stack((gi, gj, a), 0)[:, iou_order]
# u = torch.stack((gi, gj, a),0) # _, first_unique = np.unique(u, axis=1, return_index=True) # first unique indices
_, first_unique = np.unique(u[:, iou_order], axis=1, return_index=True) # first unique indices first_unique = return_torch_unique_index(u, torch.unique(u, dim=1)) # torch alternative
# _, first_unique = torch.unique(u[:, iou_order], dim=1, return_inverse=True) # different than numpy?
i = iou_order[first_unique] i = iou_order[first_unique]
# best anchor must share significant commonality (iou) with target # best anchor must share significant commonality (iou) with target
i = i[iou_best[i] > 0.10] # TODO: arbitrary threshold is problematic i = i[iou_best[i] > 0.10] # TODO: examine arbitrary threshold
if len(i) == 0: if len(i) == 0:
continue continue
@ -428,6 +427,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
return output return output
def return_torch_unique_index(u, uv):
n = uv.shape[1] # number of columns
first_unique = torch.zeros(n, device=u.device).long()
for j in range(n):
first_unique[j] = (uv[:, j:j + 1] == u).all(0).nonzero()[0]
return first_unique
def strip_optimizer_from_checkpoint(filename='weights/best.pt'): def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
# Strip optimizer from *.pt files for lighter files (reduced by 2/3 size) # Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)