diff --git a/test.py b/test.py index 7cd13011..c5e166bb 100644 --- a/test.py +++ b/test.py @@ -14,7 +14,7 @@ def test(cfg, batch_size=16, img_size=416, conf_thres=0.001, - iou_thres=0.5, + iou_thres=0.5, # for nms save_json=False, model=None, dataloader=None): @@ -48,11 +48,9 @@ def test(cfg, nc = int(data['classes']) # number of classes path = data['valid'] # path to test images names = load_classes(data['names']) # class names - # iou_thres = torch.linspace(0.5, 0.95, 10).to(device) # for mAP@0.5:0.95 - # iou_thres = iou_thres[0].view(1) # for mAP@0.5 - if isinstance(iou_thres, float): - iou_thres = torch.Tensor([iou_thres]).to(device) # convert to array - niou = iou_thres.numel() + iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95 + iouv = iouv[0].view(1) # for mAP@0.5 + niou = iouv.numel() # Dataloader if dataloader is None: @@ -145,11 +143,11 @@ def test(cfg, ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices # Append detections - for j in (ious > iou_thres[0]).nonzero(): + for j in (ious > iouv[0]).nonzero(): d = ti[i[j]] # detected target if d not in detected: detected.append(d) - correct[pi[j]] = ious[j] > iou_thres # iou_thres is 1xn + correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn if len(detected) == nl: # all targets already located in image break