This commit is contained in:
Glenn Jocher 2018-09-10 16:31:56 +02:00
parent 751e02de3e
commit 34144aabe3
1 changed files with 7 additions and 2 deletions

View File

@ -99,15 +99,20 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
# Find unique classes
unique_classes = target_cls #np.unique(np.concatenate((pred_cls, target_cls), 0))
unique_classes = np.unique(np.concatenate((pred_cls, target_cls), 0))
# Create Precision-Recall curve and compute AP for each class
ap = []
for c in unique_classes:
i = pred_cls == c
n_gt = sum(target_cls == c) # Number of ground truth objects
n_p = sum(i) # Number of predicted objects
if sum(i) == 0:
if (n_p == 0) and (n_gt == 0):
continue
elif (np == 0) and (n_gt > 0):
ap.append(0)
elif (n_p > 0) and (n_gt == 0):
ap.append(0)
else:
# Accumulate FPs and TPs