From 0f225afe330a190d48f2a864bde9c48ba9467b05 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 24 Dec 2019 12:42:22 -0800 Subject: [PATCH] updates --- test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test.py b/test.py index b05be9ad..1461c714 100644 --- a/test.py +++ b/test.py @@ -106,6 +106,9 @@ def test(cfg, # with open('test.txt', 'a') as file: # [file.write('%11.5g' * 7 % tuple(x) + '\n') for x in pred] + # Clip boxes to image bounds + clip_coords(pred, (height, width)) + # Append to pycocotools JSON dictionary if save_json: # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... @@ -120,11 +123,8 @@ def test(cfg, 'bbox': [floatn(x, 3) for x in box[di]], 'score': floatn(d[4], 5)}) - # Clip boxes to image bounds - clip_coords(pred, (height, width)) - # Assign all predictions as incorrect - correct = torch.zeros(len(pred), niou) + correct = torch.zeros(len(pred), niou, dtype=torch.bool) if nl: detected = [] # target indices tcls_tensor = labels[:, 0] @@ -147,7 +147,7 @@ def test(cfg, d = ti[i[j]] # detected target if d not in detected: detected.append(d) - correct[pi[j]] = (ious[j] > iou_thres).float() # iou_thres is 1xn + correct[pi[j]] = ious[j] > iou_thres # iou_thres is 1xn if len(detected) == nl: # all targets already located in image break