Glenn Jocher 2018-11-23 18:09:47 +01:00
parent bf66656b4e
commit 7b13af707d
1 changed file with 14 additions and 20 deletions

test.py

@@ -5,7 +5,7 @@ from utils.datasets import *
 from utils.utils import *
 
 parser = argparse.ArgumentParser(prog='test.py')
-parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
+parser.add_argument('-batch_size', type=int, default=64, help='size of each image batch')
 parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
 parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
 parser.add_argument('-weights_path', type=str, default='weights/yolov3.pt', help='path to weights file')
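
Note on the hunk above: the only change is the -batch_size default (32 to 64); the flag itself is unchanged, so the old value remains available from the command line. A minimal sketch of that behaviour, reusing the same argparse setup:

    import argparse

    # Same flag as in test.py above; only the default moved from 32 to 64.
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('-batch_size', type=int, default=64, help='size of each image batch')

    print(parser.parse_args([]).batch_size)                     # 64, the new default
    print(parser.parse_args(['-batch_size', '32']).batch_size)  # 32, the old value, opt-in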
@@ -53,41 +53,35 @@ def main(opt):
     outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
     AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
 
     for batch_i, (imgs, targets) in enumerate(dataloader):
-        imgs = imgs.to(device)
-
         with torch.no_grad():
-            output = model(imgs)
+            output = model(imgs.to(device))
             output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)
 
         # Compute average precision for each sample
-        for sample_i in range(len(targets)):
+        for sample_i, (labels, detections) in enumerate(zip(targets, output)):
             correct = []
-            # Get labels for sample where width is not zero (dummies)
-            annotations = targets[sample_i]
-            # Extract detections
-            detections = output[sample_i]
+
             if detections is None:
-                # If there are no detections but there are annotations mask as zero AP
-                if annotations.size(0) != 0:
-                    mAPs.append(0)
+                # If there are no detections but there are labels mask as zero AP
+                if labels.size(0) != 0:
+                    mAPs.append(0), mR.append(0), mP.append(0)
                 continue
 
             # Get detections sorted by decreasing confidence scores
             detections = detections.cpu().numpy()
             detections = detections[np.argsort(-detections[:, 4])]
 
-            # If no annotations add number of detections as incorrect
-            if annotations.size(0) == 0:
+            # If no labels add number of detections as incorrect
+            if labels.size(0) == 0:
                 # correct.extend([0 for _ in range(len(detections))])
-                mAPs.append(0)
+                mAPs.append(0), mR.append(0), mP.append(0)
                 continue
             else:
-                target_cls = annotations[:, 0]
+                target_cls = labels[:, 0]
 
                 # Extract target boxes as (x1, y1, x2, y2)
-                target_boxes = xywh2xyxy(annotations[:, 1:5])
-                target_boxes *= opt.img_size
+                target_boxes = xywh2xyxy(labels[:, 1:5]) * opt.img_size
 
                 detected = []
                 for *pred_bbox, conf, obj_conf, obj_pred in detections:
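
Note on the hunk above: the refactor replaces the index lookups (targets[sample_i], output[sample_i]) with a single zip, renames annotations to labels, and folds the two-step box conversion into one expression. It also keeps mR.append(0) and mP.append(0) in lockstep with mAPs.append(0) in both early-exit branches, so all three running means are averaged over the same number of samples. A self-contained NumPy sketch of what xywh2xyxy does here (the repo's own helper works on torch tensors, so treat this as an illustration, not the library function):

    import numpy as np

    def xywh2xyxy_sketch(x):
        # (center_x, center_y, width, height) -> (x1, y1, x2, y2)
        y = np.copy(x)
        y[:, 0] = x[:, 0] - x[:, 2] / 2  # x1
        y[:, 1] = x[:, 1] - x[:, 3] / 2  # y1
        y[:, 2] = x[:, 0] + x[:, 2] / 2  # x2
        y[:, 3] = x[:, 1] + x[:, 3] / 2  # y2
        return y

    labels = np.array([[0.5, 0.5, 0.25, 0.5]])  # one box, normalized xywh
    print(xywh2xyxy_sketch(labels) * 416)       # corner coords scaled to a 416-px image,
                                                # mirroring xywh2xyxy(labels[:, 1:5]) * opt.img_size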
@@ -98,7 +92,7 @@ def main(opt):
                     # Extract index of largest overlap
                     best_i = np.argmax(iou)
                     # If overlap exceeds threshold and classification is correct mark as correct
-                    if iou[best_i] > opt.iou_thres and obj_pred == annotations[best_i, 0] and best_i not in detected:
+                    if iou[best_i] > opt.iou_thres and obj_pred == labels[best_i, 0] and best_i not in detected:
                         correct.append(1)
                         detected.append(best_i)
                     else:
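
Note on the hunk above: only annotations is renamed to labels; the matching rule is unchanged. A detection counts as correct when its best-overlap ground-truth box exceeds opt.iou_thres, the predicted class agrees, and that ground truth has not already been claimed (detected tracks claimed indices, making the assignment greedy in confidence order). A small sketch of the IoU test with an assumed helper, not the repo's bbox_iou:

    import numpy as np

    def iou_xyxy(box, boxes):
        # IoU of one (x1, y1, x2, y2) box against an (N, 4) array of boxes.
        x1 = np.maximum(box[0], boxes[:, 0])
        y1 = np.maximum(box[1], boxes[:, 1])
        x2 = np.minimum(box[2], boxes[:, 2])
        y2 = np.minimum(box[3], boxes[:, 3])
        inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
        area_a = (box[2] - box[0]) * (box[3] - box[1])
        area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        return inter / (area_a + area_b - inter)

    target_boxes = np.array([[0., 0., 10., 10.], [20., 20., 30., 30.]])
    iou = iou_xyxy(np.array([1., 1., 11., 11.]), target_boxes)
    best_i = np.argmax(iou)     # index of the largest overlap, as in the hunk
    print(best_i, iou[best_i])  # 0 0.6806722689075631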
@@ -123,7 +117,7 @@ def main(opt):
         mean_P = np.mean(mP)
 
         # Print image mAP and running mean mAP
-        print(('%11s%11s' + '%11.3g' * 3) % (len(mAPs), len(dataloader) * opt.batch_size, mean_P, mean_R, mean_mAP))
+        print(('%11s%11s' + '%11.3g' * 3) % (len(mAPs), dataloader.nF, mean_P, mean_R, mean_mAP))
 
     # Print mAP per class
     print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP') + '\n\nmAP Per Class:')
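
Note on the last hunk: len(dataloader) * opt.batch_size overstates the image count in the 'Total' column whenever the final batch is partial, while dataloader.nF (assuming nF is the loader's total image count, as its use here implies) is exact. A worked example of the discrepancy, with assumed numbers:

    import math

    nF, batch_size = 5000, 64                 # e.g. a 5000-image val set, new default batch
    num_batches = math.ceil(nF / batch_size)  # 79 batches, last one partial
    print(num_batches * batch_size)           # 5056 <- old 'Total' column, overcounted
    print(nF)                                 # 5000 <- new 'Total' column, exact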