This commit is contained in:
Glenn Jocher 2019-12-22 16:05:43 -08:00
parent 0e54731bb8
commit 0e17fb5905
1 changed files with 17 additions and 19 deletions

36
test.py
View File

@ -126,7 +126,7 @@ def test(cfg,
# Assign all predictions as incorrect
correct = torch.zeros(len(pred), niou)
if nl:
detected = []
detected = [] # target indices
tcls_tensor = labels[:, 0]
# target boxes
@ -134,26 +134,24 @@ def test(cfg,
tbox[:, [0, 2]] *= width
tbox[:, [1, 3]] *= height
# Search for correct predictions
for i, (*pbox, _, pcls) in enumerate(pred):
# Per target class
for cls in torch.unique(tcls_tensor):
ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices
pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices
# Break if all targets already located in image
if len(detected) == nl:
break
# Search for detections
if len(pi):
# Prediction to target ious
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
# Continue if predicted class not among image classes
if pcls.item() not in tcls:
continue
# Best iou, index between pred and targets
m = (pcls == tcls_tensor).nonzero().view(-1)
iou, j = bbox_iou(pbox, tbox[m]).max(0)
m = m[j]
# Per iou_thres 'correct' vector
if iou > iou_thres[0] and m not in detected:
detected.append(m)
correct[i] = iou > iou_thres
# Append detections
for j in (ious > iou_thres[0]).nonzero():
d = ti[i[j]] # detected target
if d not in detected:
detected.append(d)
correct[pi[j]] = (ious[j] > iou_thres).float() # iou_thres is 1xn
if len(detected) == nl: # all targets already located in image
break
# Append statistics (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))