From 0e17fb5905ec5a335f86c5b17a945691c09d568e Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 22 Dec 2019 16:05:43 -0800 Subject: [PATCH] updates --- test.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/test.py b/test.py index bdcf8cf5..dc6c9f50 100644 --- a/test.py +++ b/test.py @@ -126,7 +126,7 @@ def test(cfg, # Assign all predictions as incorrect correct = torch.zeros(len(pred), niou) if nl: - detected = [] + detected = [] # target indices tcls_tensor = labels[:, 0] # target boxes @@ -134,26 +134,24 @@ def test(cfg, tbox[:, [0, 2]] *= width tbox[:, [1, 3]] *= height - # Search for correct predictions - for i, (*pbox, _, pcls) in enumerate(pred): + # Per target class + for cls in torch.unique(tcls_tensor): + ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices + pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices - # Break if all targets already located in image - if len(detected) == nl: - break + # Search for detections + if len(pi): + # Prediction to target ious + ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices - # Continue if predicted class not among image classes - if pcls.item() not in tcls: - continue - - # Best iou, index between pred and targets - m = (pcls == tcls_tensor).nonzero().view(-1) - iou, j = bbox_iou(pbox, tbox[m]).max(0) - m = m[j] - - # Per iou_thres 'correct' vector - if iou > iou_thres[0] and m not in detected: - detected.append(m) - correct[i] = iou > iou_thres + # Append detections + for j in (ious > iou_thres[0]).nonzero(): + d = ti[i[j]] # detected target + if d not in detected: + detected.append(d) + correct[pi[j]] = (ious[j] > iou_thres).float() # iou_thres is 1xn + if len(detected) == nl: # all targets already located in image + break # Append statistics (correct, conf, pcls, tcls) stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))