updates
This commit is contained in:
parent
0e54731bb8
commit
0e17fb5905
36
test.py
36
test.py
|
@ -126,7 +126,7 @@ def test(cfg,
|
||||||
# Assign all predictions as incorrect
|
# Assign all predictions as incorrect
|
||||||
correct = torch.zeros(len(pred), niou)
|
correct = torch.zeros(len(pred), niou)
|
||||||
if nl:
|
if nl:
|
||||||
detected = []
|
detected = [] # target indices
|
||||||
tcls_tensor = labels[:, 0]
|
tcls_tensor = labels[:, 0]
|
||||||
|
|
||||||
# target boxes
|
# target boxes
|
||||||
|
@ -134,27 +134,25 @@ def test(cfg,
|
||||||
tbox[:, [0, 2]] *= width
|
tbox[:, [0, 2]] *= width
|
||||||
tbox[:, [1, 3]] *= height
|
tbox[:, [1, 3]] *= height
|
||||||
|
|
||||||
# Search for correct predictions
|
# Per target class
|
||||||
for i, (*pbox, _, pcls) in enumerate(pred):
|
for cls in torch.unique(tcls_tensor):
|
||||||
|
ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices
|
||||||
|
pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices
|
||||||
|
|
||||||
# Break if all targets already located in image
|
# Search for detections
|
||||||
if len(detected) == nl:
|
if len(pi):
|
||||||
|
# Prediction to target ious
|
||||||
|
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
|
||||||
|
|
||||||
|
# Append detections
|
||||||
|
for j in (ious > iou_thres[0]).nonzero():
|
||||||
|
d = ti[i[j]] # detected target
|
||||||
|
if d not in detected:
|
||||||
|
detected.append(d)
|
||||||
|
correct[pi[j]] = (ious[j] > iou_thres).float() # iou_thres is 1xn
|
||||||
|
if len(detected) == nl: # all targets already located in image
|
||||||
break
|
break
|
||||||
|
|
||||||
# Continue if predicted class not among image classes
|
|
||||||
if pcls.item() not in tcls:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Best iou, index between pred and targets
|
|
||||||
m = (pcls == tcls_tensor).nonzero().view(-1)
|
|
||||||
iou, j = bbox_iou(pbox, tbox[m]).max(0)
|
|
||||||
m = m[j]
|
|
||||||
|
|
||||||
# Per iou_thres 'correct' vector
|
|
||||||
if iou > iou_thres[0] and m not in detected:
|
|
||||||
detected.append(m)
|
|
||||||
correct[i] = iou > iou_thres
|
|
||||||
|
|
||||||
# Append statistics (correct, conf, pcls, tcls)
|
# Append statistics (correct, conf, pcls, tcls)
|
||||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
|
stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue