updates
This commit is contained in:
parent
02d6b2f9c5
commit
40221894c2
21
test.py
21
test.py
|
@ -133,18 +133,29 @@ def test(
|
|||
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
|
||||
|
||||
# Compute statistics
|
||||
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
|
||||
nt = np.bincount(stats_np[3].astype(np.int64), minlength=nc) # number of targets per class
|
||||
if len(stats_np):
|
||||
p, r, ap, f1, ap_class = ap_per_class(*stats_np)
|
||||
stats = [np.concatenate(x, 0) for x in list(zip(*stats))] # to numpy
|
||||
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
|
||||
if len(stats):
|
||||
p, r, ap, f1, ap_class = ap_per_class(*stats)
|
||||
mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
|
||||
|
||||
if any(r > 1):
|
||||
chkpt = {'epoch': -1,
|
||||
'best_loss': None,
|
||||
'model': model.module.state_dict() if type(
|
||||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||
'optimizer': None}
|
||||
|
||||
# Save problem checkpoint
|
||||
torch.save(chkpt, 'recall_issue.pt')
|
||||
del chkpt
|
||||
|
||||
# Print results
|
||||
pf = '%20s' + '%10.3g' * 6 # print format
|
||||
print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n')
|
||||
|
||||
# Print results per class
|
||||
if nc > 1 and len(stats_np):
|
||||
if nc > 1 and len(stats):
|
||||
for i, c in enumerate(ap_class):
|
||||
print(pf % (names[c], seen, nt[c], p[i], r[i], ap[i], f1[i]))
|
||||
|
||||
|
|
Loading…
Reference in New Issue