From 545f756090e142412e40589c8c30ad58baade744 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 28 Feb 2019 15:40:30 +0100 Subject: [PATCH] updates --- utils/utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 8d2cf95b..53defd08 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -236,17 +236,17 @@ def build_targets(target, anchor_wh, nA, nC, nG): returns nT, nCorrect, tx, ty, tw, th, tconf, tcls """ nB = len(target) # number of images in batch - nT = [len(x) for x in target] + txy = torch.zeros(nB, nA, nG, nG, 2) # batch size, anchors, grid size twh = torch.zeros(nB, nA, nG, nG, 2) tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0) tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0) # nC = number of classes for b in range(nB): - nTb = nT[b] # number of targets + t = target[b] + nTb = len(t) # number of targets if nTb == 0: continue - t = target[b] gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG @@ -267,14 +267,13 @@ def build_targets(target, anchor_wh, nA, nC, nG): iou_order = torch.argsort(-iou_best) # best to worst # Unique anchor selection - u = torch.cat((gi, gj, a), 0).view((3, -1)) - # u = torch.stack((gi, gj, a),0) - _, first_unique = np.unique(u[:, iou_order], axis=1, return_index=True) # first unique indices - # _, first_unique = torch.unique(u[:, iou_order], dim=1, return_inverse=True) # different than numpy? + u = torch.stack((gi, gj, a), 0)[:, iou_order] + # _, first_unique = np.unique(u, axis=1, return_index=True) # first unique indices + first_unique = return_torch_unique_index(u, torch.unique(u, dim=1)) # torch alternative i = iou_order[first_unique] # best anchor must share significant commonality (iou) with target - i = i[iou_best[i] > 0.10] # TODO: arbitrary threshold is problematic + i = i[iou_best[i] > 0.10] # TODO: examine arbitrary threshold if len(i) == 0: continue @@ -428,6 +427,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): return output +def return_torch_unique_index(u, uv): + n = uv.shape[1] # number of columns + first_unique = torch.zeros(n, device=u.device).long() + for j in range(n): + first_unique[j] = (uv[:, j:j + 1] == u).all(0).nonzero()[0] + + return first_unique + + def strip_optimizer_from_checkpoint(filename='weights/best.pt'): # Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)