This commit is contained in:
Glenn Jocher 2019-04-18 21:18:54 +02:00
parent 02d6b2f9c5
commit 40221894c2
1 changed files with 16 additions and 5 deletions

21
test.py
View File

@ -133,18 +133,29 @@ def test(
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls)) stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
# Compute statistics # Compute statistics
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))] stats = [np.concatenate(x, 0) for x in list(zip(*stats))] # to numpy
nt = np.bincount(stats_np[3].astype(np.int64), minlength=nc) # number of targets per class nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
if len(stats_np): if len(stats):
p, r, ap, f1, ap_class = ap_per_class(*stats_np) p, r, ap, f1, ap_class = ap_per_class(*stats)
mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean() mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
if any(r > 1):
chkpt = {'epoch': -1,
'best_loss': None,
'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': None}
# Save problem checkpoint
torch.save(chkpt, 'recall_issue.pt')
del chkpt
# Print results # Print results
pf = '%20s' + '%10.3g' * 6 # print format pf = '%20s' + '%10.3g' * 6 # print format
print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n') print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n')
# Print results per class # Print results per class
if nc > 1 and len(stats_np): if nc > 1 and len(stats):
for i, c in enumerate(ap_class): for i, c in enumerate(ap_class):
print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i])) print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))