updates
This commit is contained in:
parent
dc43968918
commit
9d54a268c9
4
test.py
4
test.py
|
@ -144,10 +144,12 @@ def test(cfg,
|
||||||
|
|
||||||
# Compute statistics
|
# Compute statistics
|
||||||
stats = [np.concatenate(x, 0) for x in list(zip(*stats))] # to numpy
|
stats = [np.concatenate(x, 0) for x in list(zip(*stats))] # to numpy
|
||||||
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
|
|
||||||
if len(stats):
|
if len(stats):
|
||||||
p, r, ap, f1, ap_class = ap_per_class(*stats)
|
p, r, ap, f1, ap_class = ap_per_class(*stats)
|
||||||
mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
|
mp, mr, map, mf1 = p.mean(), r.mean(), ap.mean(), f1.mean()
|
||||||
|
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
|
||||||
|
else:
|
||||||
|
nt = torch.zeros(1)
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
pf = '%30s' + '%10.3g' * 6 # print format
|
pf = '%30s' + '%10.3g' * 6 # print format
|
||||||
|
|
Loading…
Reference in New Issue