This commit is contained in:
Glenn Jocher 2018-09-10 16:31:56 +02:00
parent 751e02de3e
commit 34144aabe3
1 changed files with 7 additions and 2 deletions

View File

@ -99,15 +99,20 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
# Find unique classes # Find unique classes
unique_classes = target_cls #np.unique(np.concatenate((pred_cls, target_cls), 0)) unique_classes = np.unique(np.concatenate((pred_cls, target_cls), 0))
# Create Precision-Recall curve and compute AP for each class # Create Precision-Recall curve and compute AP for each class
ap = [] ap = []
for c in unique_classes: for c in unique_classes:
i = pred_cls == c i = pred_cls == c
n_gt = sum(target_cls == c) # Number of ground truth objects n_gt = sum(target_cls == c) # Number of ground truth objects
n_p = sum(i) # Number of predicted objects
if sum(i) == 0: if (n_p == 0) and (n_gt == 0):
continue
elif (np == 0) and (n_gt > 0):
ap.append(0)
elif (n_p > 0) and (n_gt == 0):
ap.append(0) ap.append(0)
else: else:
# Accumulate FPs and TPs # Accumulate FPs and TPs