diff --git a/test.py b/test.py index 845abfb3..b7d33d37 100644 --- a/test.py +++ b/test.py @@ -124,11 +124,11 @@ def test(cfg, scale_coords(imgs[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape box = xyxy2xywh(box) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner - for di, d in enumerate(pred): + for p, b in zip(pred.tolist(), box.tolist()): jdict.append({'image_id': image_id, - 'category_id': coco91class[int(d[5])], - 'bbox': [round(x, 3) for x in box[di].tolist()], - 'score': round(d[4].item(), 5)}) + 'category_id': coco91class[int(p[5])], + 'bbox': [round(x, 3) for x in b], + 'score': round(p[4], 5)}) # Assign all predictions as incorrect correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)