updates
This commit is contained in:
parent
bf66656b4e
commit
7b13af707d
34
test.py
34
test.py
|
@ -5,7 +5,7 @@ from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(prog='test.py')
|
parser = argparse.ArgumentParser(prog='test.py')
|
||||||
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
|
parser.add_argument('-batch_size', type=int, default=64, help='size of each image batch')
|
||||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
||||||
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
|
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
|
||||||
parser.add_argument('-weights_path', type=str, default='weights/yolov3.pt', help='path to weights file')
|
parser.add_argument('-weights_path', type=str, default='weights/yolov3.pt', help='path to weights file')
|
||||||
|
@ -53,41 +53,35 @@ def main(opt):
|
||||||
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
|
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
|
||||||
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
||||||
for batch_i, (imgs, targets) in enumerate(dataloader):
|
for batch_i, (imgs, targets) in enumerate(dataloader):
|
||||||
imgs = imgs.to(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(imgs)
|
output = model(imgs.to(device))
|
||||||
output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)
|
output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)
|
||||||
|
|
||||||
# Compute average precision for each sample
|
# Compute average precision for each sample
|
||||||
for sample_i in range(len(targets)):
|
for sample_i, (labels, detections) in enumerate(zip(targets, output)):
|
||||||
correct = []
|
correct = []
|
||||||
|
|
||||||
# Get labels for sample where width is not zero (dummies)
|
|
||||||
annotations = targets[sample_i]
|
|
||||||
# Extract detections
|
|
||||||
detections = output[sample_i]
|
|
||||||
|
|
||||||
if detections is None:
|
if detections is None:
|
||||||
# If there are no detections but there are annotations mask as zero AP
|
# If there are no detections but there are labels mask as zero AP
|
||||||
if annotations.size(0) != 0:
|
if labels.size(0) != 0:
|
||||||
mAPs.append(0)
|
mAPs.append(0), mR.append(0), mP.append(0)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get detections sorted by decreasing confidence scores
|
# Get detections sorted by decreasing confidence scores
|
||||||
|
detections = detections.cpu().numpy()
|
||||||
detections = detections[np.argsort(-detections[:, 4])]
|
detections = detections[np.argsort(-detections[:, 4])]
|
||||||
|
|
||||||
# If no annotations add number of detections as incorrect
|
# If no labels add number of detections as incorrect
|
||||||
if annotations.size(0) == 0:
|
if labels.size(0) == 0:
|
||||||
# correct.extend([0 for _ in range(len(detections))])
|
# correct.extend([0 for _ in range(len(detections))])
|
||||||
mAPs.append(0)
|
mAPs.append(0), mR.append(0), mP.append(0)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
target_cls = annotations[:, 0]
|
target_cls = labels[:, 0]
|
||||||
|
|
||||||
# Extract target boxes as (x1, y1, x2, y2)
|
# Extract target boxes as (x1, y1, x2, y2)
|
||||||
target_boxes = xywh2xyxy(annotations[:, 1:5])
|
target_boxes = xywh2xyxy(labels[:, 1:5]) * opt.img_size
|
||||||
target_boxes *= opt.img_size
|
|
||||||
|
|
||||||
detected = []
|
detected = []
|
||||||
for *pred_bbox, conf, obj_conf, obj_pred in detections:
|
for *pred_bbox, conf, obj_conf, obj_pred in detections:
|
||||||
|
@ -98,7 +92,7 @@ def main(opt):
|
||||||
# Extract index of largest overlap
|
# Extract index of largest overlap
|
||||||
best_i = np.argmax(iou)
|
best_i = np.argmax(iou)
|
||||||
# If overlap exceeds threshold and classification is correct mark as correct
|
# If overlap exceeds threshold and classification is correct mark as correct
|
||||||
if iou[best_i] > opt.iou_thres and obj_pred == annotations[best_i, 0] and best_i not in detected:
|
if iou[best_i] > opt.iou_thres and obj_pred == labels[best_i, 0] and best_i not in detected:
|
||||||
correct.append(1)
|
correct.append(1)
|
||||||
detected.append(best_i)
|
detected.append(best_i)
|
||||||
else:
|
else:
|
||||||
|
@ -123,7 +117,7 @@ def main(opt):
|
||||||
mean_P = np.mean(mP)
|
mean_P = np.mean(mP)
|
||||||
|
|
||||||
# Print image mAP and running mean mAP
|
# Print image mAP and running mean mAP
|
||||||
print(('%11s%11s' + '%11.3g' * 3) % (len(mAPs), len(dataloader) * opt.batch_size, mean_P, mean_R, mean_mAP))
|
print(('%11s%11s' + '%11.3g' * 3) % (len(mAPs), dataloader.nF, mean_P, mean_R, mean_mAP))
|
||||||
|
|
||||||
# Print mAP per class
|
# Print mAP per class
|
||||||
print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP') + '\n\nmAP Per Class:')
|
print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP') + '\n\nmAP Per Class:')
|
||||||
|
|
Loading…
Reference in New Issue