updates
This commit is contained in:
parent
45b7dfc054
commit
2cc805edda
14
test.py
14
test.py
|
@ -14,7 +14,7 @@ def test(cfg,
|
|||
batch_size=16,
|
||||
img_size=416,
|
||||
conf_thres=0.001,
|
||||
iou_thres=0.5,
|
||||
iou_thres=0.5, # for nms
|
||||
save_json=False,
|
||||
model=None,
|
||||
dataloader=None):
|
||||
|
@ -48,11 +48,9 @@ def test(cfg,
|
|||
nc = int(data['classes']) # number of classes
|
||||
path = data['valid'] # path to test images
|
||||
names = load_classes(data['names']) # class names
|
||||
# iou_thres = torch.linspace(0.5, 0.95, 10).to(device) # for mAP@0.5:0.95
|
||||
# iou_thres = iou_thres[0].view(1) # for mAP@0.5
|
||||
if isinstance(iou_thres, float):
|
||||
iou_thres = torch.Tensor([iou_thres]).to(device) # convert to array
|
||||
niou = iou_thres.numel()
|
||||
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
|
||||
iouv = iouv[0].view(1) # for mAP@0.5
|
||||
niou = iouv.numel()
|
||||
|
||||
# Dataloader
|
||||
if dataloader is None:
|
||||
|
@ -145,11 +143,11 @@ def test(cfg,
|
|||
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
|
||||
|
||||
# Append detections
|
||||
for j in (ious > iou_thres[0]).nonzero():
|
||||
for j in (ious > iouv[0]).nonzero():
|
||||
d = ti[i[j]] # detected target
|
||||
if d not in detected:
|
||||
detected.append(d)
|
||||
correct[pi[j]] = ious[j] > iou_thres # iou_thres is 1xn
|
||||
correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
|
||||
if len(detected) == nl: # all targets already located in image
|
||||
break
|
||||
|
||||
|
|
Loading…
Reference in New Issue