diff --git a/test.py b/test.py index 642fdbcf..7cd13011 100644 --- a/test.py +++ b/test.py @@ -50,7 +50,9 @@ def test(cfg, names = load_classes(data['names']) # class names # iou_thres = torch.linspace(0.5, 0.95, 10).to(device) # for mAP@0.5:0.95 # iou_thres = iou_thres[0].view(1) # for mAP@0.5 - niou = 1 # len(iou_thres) + if isinstance(iou_thres, float): + iou_thres = torch.Tensor([iou_thres]).to(device) # convert to array + niou = iou_thres.numel() # Dataloader if dataloader is None: