updates
This commit is contained in:
parent
02d6b2f9c5
commit
40221894c2
21
test.py
21
test.py
|
@ -133,18 +133,29 @@ def test(
|
||||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
|
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
|
||||||
|
|
||||||
# Compute statistics
|
# Compute statistics
|
||||||
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
|
stats = [np.concatenate(x, 0) for x in list(zip(*stats))] # to numpy
|
||||||
nt = np.bincount(stats_np[3].astype(np.int64), minlength=nc) # number of targets per class
|
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
|
||||||
if len(stats_np):
|
if len(stats):
|
||||||
p, r, ap, f1, ap_class = ap_per_class(*stats_np)
|
p, r, ap, f1, ap_class = ap_per_class(*stats)
|
||||||
mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
|
mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
|
||||||
|
|
||||||
|
if any(r > 1):
|
||||||
|
chkpt = {'epoch': -1,
|
||||||
|
'best_loss': None,
|
||||||
|
'model': model.module.state_dict() if type(
|
||||||
|
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||||
|
'optimizer': None}
|
||||||
|
|
||||||
|
# Save problem checkpoint
|
||||||
|
torch.save(chkpt, 'recall_issue.pt')
|
||||||
|
del chkpt
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
pf = '%20s' + '%10.3g' * 6 # print format
|
pf = '%20s' + '%10.3g' * 6 # print format
|
||||||
print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n')
|
print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n')
|
||||||
|
|
||||||
# Print results per class
|
# Print results per class
|
||||||
if nc > 1 and len(stats_np):
|
if nc > 1 and len(stats):
|
||||||
for i, c in enumerate(ap_class):
|
for i, c in enumerate(ap_class):
|
||||||
print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))
|
print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue