From d336e0053df5c29341634394df0f97c02423cbd7 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 10 Oct 2018 17:07:21 +0200 Subject: [PATCH] per-class mAP report --- models.py | 3 +++ test.py | 17 +++++++++++++++-- utils/utils.py | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/models.py b/models.py index 512d89c0..db2395b2 100755 --- a/models.py +++ b/models.py @@ -99,6 +99,8 @@ class YOLOLayer(nn.Module): self.scaled_anchors = torch.FloatTensor([(a_w / stride, a_h / stride) for a_w, a_h in anchors]) self.anchor_w = self.scaled_anchors[:, 0:1].view((1, nA, 1, 1)) self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1)) + self.weights = class_weights() + def forward(self, p, targets=None, requestPrecision=False): FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor @@ -110,6 +112,7 @@ class YOLOLayer(nn.Module): if p.is_cuda and not self.grid_x.is_cuda: self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda() + self.weights = self.weights.cuda() # p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh) p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction diff --git a/test.py b/test.py index b88faf0a..e8ca3163 100644 --- a/test.py +++ b/test.py @@ -1,4 +1,5 @@ import argparse + from models import * from utils.datasets import * from utils.utils import * @@ -48,9 +49,11 @@ dataloader = load_images_and_labels(test_path, batch_size=opt.batch_size, img_si print('Compute mAP...') +nC = 80 # number of classes correct = 0 targets = None outputs, mAPs, TP, confidence, pred_class, target_class = [], [], [], [], [], [] +AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC) for batch_i, (imgs, targets) in enumerate(dataloader): imgs = imgs.to(device) @@ -79,7 +82,7 @@ for batch_i, (imgs, targets) in enumerate(dataloader): # If no annotations add number of detections as incorrect if annotations.size(0) == 0: target_cls = [] - #correct.extend([0 for _ in range(len(detections))]) + # correct.extend([0 for _ in range(len(detections))]) mAPs.append(0) continue else: @@ -105,7 +108,11 @@ for batch_i, (imgs, targets) in enumerate(dataloader): correct.append(0) # Compute Average Precision (AP) per class - AP = ap_per_class(tp=correct, conf=detections[:, 4], pred_cls=detections[:, 6], target_cls=target_cls) + AP, AP_class = ap_per_class(tp=correct, conf=detections[:, 4], pred_cls=detections[:, 6], target_cls=target_cls) + + # Accumulate AP per class + AP_accum_count += np.bincount(AP_class, minlength=nC) + AP_accum += np.bincount(AP_class, minlength=nC, weights=AP) # Compute mean AP for this image mAP = AP.mean() @@ -116,4 +123,10 @@ for batch_i, (imgs, targets) in enumerate(dataloader): # Print image mAP and running mean mAP print('+ Sample [%d/%d] AP: %.4f (%.4f)' % (len(mAPs), len(dataloader) * opt.batch_size, mAP, np.mean(mAPs))) +# Print mAP per class +classes = load_classes(opt.class_path) # Extracts class labels from file +for i, c in enumerate(classes): + print('%15s: %-.4f' % (c, AP_accum[i] / AP_accum_count[i])) + +# Print mAP print('Mean Average Precision: %.4f' % np.mean(mAPs)) diff --git a/utils/utils.py b/utils/utils.py index 1b1c67a1..8b525437 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -130,7 +130,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls): # AP from recall-precision curve ap.append(compute_ap(recall, precision)) - return np.array(ap) + return np.array(ap), unique_classes.astype('int32') def compute_ap(recall, precision):