updates
This commit is contained in:
parent
45b7dfc054
commit
2cc805edda
14
test.py
14
test.py
|
@ -14,7 +14,7 @@ def test(cfg,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
img_size=416,
|
img_size=416,
|
||||||
conf_thres=0.001,
|
conf_thres=0.001,
|
||||||
iou_thres=0.5,
|
iou_thres=0.5, # for nms
|
||||||
save_json=False,
|
save_json=False,
|
||||||
model=None,
|
model=None,
|
||||||
dataloader=None):
|
dataloader=None):
|
||||||
|
@ -48,11 +48,9 @@ def test(cfg,
|
||||||
nc = int(data['classes']) # number of classes
|
nc = int(data['classes']) # number of classes
|
||||||
path = data['valid'] # path to test images
|
path = data['valid'] # path to test images
|
||||||
names = load_classes(data['names']) # class names
|
names = load_classes(data['names']) # class names
|
||||||
# iou_thres = torch.linspace(0.5, 0.95, 10).to(device) # for mAP@0.5:0.95
|
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
|
||||||
# iou_thres = iou_thres[0].view(1) # for mAP@0.5
|
iouv = iouv[0].view(1) # for mAP@0.5
|
||||||
if isinstance(iou_thres, float):
|
niou = iouv.numel()
|
||||||
iou_thres = torch.Tensor([iou_thres]).to(device) # convert to array
|
|
||||||
niou = iou_thres.numel()
|
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
if dataloader is None:
|
if dataloader is None:
|
||||||
|
@ -145,11 +143,11 @@ def test(cfg,
|
||||||
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
|
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
|
||||||
|
|
||||||
# Append detections
|
# Append detections
|
||||||
for j in (ious > iou_thres[0]).nonzero():
|
for j in (ious > iouv[0]).nonzero():
|
||||||
d = ti[i[j]] # detected target
|
d = ti[i[j]] # detected target
|
||||||
if d not in detected:
|
if d not in detected:
|
||||||
detected.append(d)
|
detected.append(d)
|
||||||
correct[pi[j]] = ious[j] > iou_thres # iou_thres is 1xn
|
correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
|
||||||
if len(detected) == nl: # all targets already located in image
|
if len(detected) == nl: # all targets already located in image
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue