This commit is contained in:
Glenn Jocher 2019-12-22 16:05:43 -08:00
parent 0e54731bb8
commit 0e17fb5905
1 changed files with 17 additions and 19 deletions

36
test.py
View File

@ -126,7 +126,7 @@ def test(cfg,
# Assign all predictions as incorrect # Assign all predictions as incorrect
correct = torch.zeros(len(pred), niou) correct = torch.zeros(len(pred), niou)
if nl: if nl:
detected = [] detected = [] # target indices
tcls_tensor = labels[:, 0] tcls_tensor = labels[:, 0]
# target boxes # target boxes
@ -134,26 +134,24 @@ def test(cfg,
tbox[:, [0, 2]] *= width tbox[:, [0, 2]] *= width
tbox[:, [1, 3]] *= height tbox[:, [1, 3]] *= height
# Search for correct predictions # Per target class
for i, (*pbox, _, pcls) in enumerate(pred): for cls in torch.unique(tcls_tensor):
ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices
pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices
# Break if all targets already located in image # Search for detections
if len(detected) == nl: if len(pi):
break # Prediction to target ious
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
# Continue if predicted class not among image classes # Append detections
if pcls.item() not in tcls: for j in (ious > iou_thres[0]).nonzero():
continue d = ti[i[j]] # detected target
if d not in detected:
# Best iou, index between pred and targets detected.append(d)
m = (pcls == tcls_tensor).nonzero().view(-1) correct[pi[j]] = (ious[j] > iou_thres).float() # iou_thres is 1xn
iou, j = bbox_iou(pbox, tbox[m]).max(0) if len(detected) == nl: # all targets already located in image
m = m[j] break
# Per iou_thres 'correct' vector
if iou > iou_thres[0] and m not in detected:
detected.append(m)
correct[i] = iou > iou_thres
# Append statistics (correct, conf, pcls, tcls) # Append statistics (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))