From d65d64bb7e28b640582febf0b65fb41804e28156 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 10 Apr 2019 16:17:08 +0200 Subject: [PATCH] updates --- test.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/test.py b/test.py index 98a53e34..0e04be8e 100644 --- a/test.py +++ b/test.py @@ -81,14 +81,13 @@ def test( # Statistics per image for si, pred in enumerate(output): labels = targets[targets[:, 0] == si, 1:] - correct, detected = [], [] - tcls = torch.Tensor() + nl = len(labels) + tcls = labels[:, 0].tolist() if nl else [] # target class seen += 1 if pred is None: - if len(labels): - tcls = labels[:, 0].cpu() # target classes - stats.append((correct, torch.Tensor(), torch.Tensor(), tcls)) + if nl: + stats.append(([], torch.Tensor(), torch.Tensor(), tcls)) continue # Append to pycocotools JSON dictionary @@ -107,14 +106,21 @@ def test( 'score': float(d[4]) }) - if len(labels): - # Extract target boxes as (x1, y1, x2, y2) + # Assign all predictions as incorrect + correct = [0] * len(pred) + if nl: + detected = [] tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes - tcls = labels[:, 0] # target classes - for *pbox, pconf, pcls_conf, pcls in pred: - if pcls not in tcls: - correct.append(0) + # Search for correct predictions + for i, (*pbox, pconf, pcls_conf, pcls) in enumerate(pred): + + # Break if all targets already located in image + if len(detected) == nl: + break + + # Continue if predicted class not among image classes + if pcls.item() not in tcls: continue # Best iou, index between pred and targets @@ -122,16 +128,11 @@ def test( # If iou > threshold and class is correct mark as correct if iou > iou_thres and bi not in detected: - correct.append(1) + correct[i] = 1 detected.append(bi) - else: - correct.append(0) - else: - # If no labels add number of detections as incorrect - correct.extend([0] * len(pred)) - # Append Statistics (correct, conf, pcls, tcls) - stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu())) + # Append statistics (correct, conf, pcls, tcls) + stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls)) # Compute statistics stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))] @@ -177,7 +178,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(prog='test.py') parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch') parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path') - parser.add_argument('--data-cfg', type=str, default='data/coco.data', help='coco.data file path') + parser.add_argument('--data-cfg', type=str, default='data/coco_10img.data', help='coco.data file path') parser.add_argument('--weights', type=str, default='weights/yolov3-spp.weights', help='path to weights file') parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected') parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')