This commit is contained in:
Glenn Jocher 2019-12-27 10:31:12 -08:00
parent 45b7dfc054
commit 2cc805edda
1 changed files with 6 additions and 8 deletions

14
test.py
View File

@ -14,7 +14,7 @@ def test(cfg,
batch_size=16, batch_size=16,
img_size=416, img_size=416,
conf_thres=0.001, conf_thres=0.001,
iou_thres=0.5, iou_thres=0.5, # for nms
save_json=False, save_json=False,
model=None, model=None,
dataloader=None): dataloader=None):
@ -48,11 +48,9 @@ def test(cfg,
nc = int(data['classes']) # number of classes nc = int(data['classes']) # number of classes
path = data['valid'] # path to test images path = data['valid'] # path to test images
names = load_classes(data['names']) # class names names = load_classes(data['names']) # class names
# iou_thres = torch.linspace(0.5, 0.95, 10).to(device) # for mAP@0.5:0.95 iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
# iou_thres = iou_thres[0].view(1) # for mAP@0.5 iouv = iouv[0].view(1) # for mAP@0.5
if isinstance(iou_thres, float): niou = iouv.numel()
iou_thres = torch.Tensor([iou_thres]).to(device) # convert to array
niou = iou_thres.numel()
# Dataloader # Dataloader
if dataloader is None: if dataloader is None:
@ -145,11 +143,11 @@ def test(cfg,
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
# Append detections # Append detections
for j in (ious > iou_thres[0]).nonzero(): for j in (ious > iouv[0]).nonzero():
d = ti[i[j]] # detected target d = ti[i[j]] # detected target
if d not in detected: if d not in detected:
detected.append(d) detected.append(d)
correct[pi[j]] = ious[j] > iou_thres # iou_thres is 1xn correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
if len(detected) == nl: # all targets already located in image if len(detected) == nl: # all targets already located in image
break break