updates
This commit is contained in:
parent
0e54731bb8
commit
0e17fb5905
36
test.py
36
test.py
|
@ -126,7 +126,7 @@ def test(cfg,
|
|||
# Assign all predictions as incorrect
|
||||
correct = torch.zeros(len(pred), niou)
|
||||
if nl:
|
||||
detected = []
|
||||
detected = [] # target indices
|
||||
tcls_tensor = labels[:, 0]
|
||||
|
||||
# target boxes
|
||||
|
@ -134,27 +134,25 @@ def test(cfg,
|
|||
tbox[:, [0, 2]] *= width
|
||||
tbox[:, [1, 3]] *= height
|
||||
|
||||
# Search for correct predictions
|
||||
for i, (*pbox, _, pcls) in enumerate(pred):
|
||||
# Per target class
|
||||
for cls in torch.unique(tcls_tensor):
|
||||
ti = (cls == tcls_tensor).nonzero().view(-1) # prediction indices
|
||||
pi = (cls == pred[:, 5]).nonzero().view(-1) # target indices
|
||||
|
||||
# Break if all targets already located in image
|
||||
if len(detected) == nl:
|
||||
# Search for detections
|
||||
if len(pi):
|
||||
# Prediction to target ious
|
||||
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
|
||||
|
||||
# Append detections
|
||||
for j in (ious > iou_thres[0]).nonzero():
|
||||
d = ti[i[j]] # detected target
|
||||
if d not in detected:
|
||||
detected.append(d)
|
||||
correct[pi[j]] = (ious[j] > iou_thres).float() # iou_thres is 1xn
|
||||
if len(detected) == nl: # all targets already located in image
|
||||
break
|
||||
|
||||
# Continue if predicted class not among image classes
|
||||
if pcls.item() not in tcls:
|
||||
continue
|
||||
|
||||
# Best iou, index between pred and targets
|
||||
m = (pcls == tcls_tensor).nonzero().view(-1)
|
||||
iou, j = bbox_iou(pbox, tbox[m]).max(0)
|
||||
m = m[j]
|
||||
|
||||
# Per iou_thres 'correct' vector
|
||||
if iou > iou_thres[0] and m not in detected:
|
||||
detected.append(m)
|
||||
correct[i] = iou > iou_thres
|
||||
|
||||
# Append statistics (correct, conf, pcls, tcls)
|
||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
|
||||
|
||||
|
|
Loading…
Reference in New Issue