diff --git a/test.py b/test.py index ca59d4f8..8bc5e1e5 100644 --- a/test.py +++ b/test.py @@ -144,10 +144,12 @@ def test(cfg, # Compute statistics stats = [np.concatenate(x, 0) for x in list(zip(*stats))] # to numpy - nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class if len(stats): p, r, ap, f1, ap_class = ap_per_class(*stats) mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean() + nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class + else: + nt = torch.zeros(1) # Print results pf = '%30s' + '%10.3g' * 6 # print format