This commit is contained in:
Glenn Jocher 2019-04-10 16:17:08 +02:00
parent bfc77ec88f
commit d65d64bb7e
1 changed files with 21 additions and 20 deletions

41
test.py
View File

@ -81,14 +81,13 @@ def test(
# Statistics per image # Statistics per image
for si, pred in enumerate(output): for si, pred in enumerate(output):
labels = targets[targets[:, 0] == si, 1:] labels = targets[targets[:, 0] == si, 1:]
correct, detected = [], [] nl = len(labels)
tcls = torch.Tensor() tcls = labels[:, 0].tolist() if nl else [] # target class
seen += 1 seen += 1
if pred is None: if pred is None:
if len(labels): if nl:
tcls = labels[:, 0].cpu() # target classes stats.append(([], torch.Tensor(), torch.Tensor(), tcls))
stats.append((correct, torch.Tensor(), torch.Tensor(), tcls))
continue continue
# Append to pycocotools JSON dictionary # Append to pycocotools JSON dictionary
@ -107,14 +106,21 @@ def test(
'score': float(d[4]) 'score': float(d[4])
}) })
if len(labels): # Assign all predictions as incorrect
# Extract target boxes as (x1, y1, x2, y2) correct = [0] * len(pred)
if nl:
detected = []
tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes
tcls = labels[:, 0] # target classes
for *pbox, pconf, pcls_conf, pcls in pred: # Search for correct predictions
if pcls not in tcls: for i, (*pbox, pconf, pcls_conf, pcls) in enumerate(pred):
correct.append(0)
# Break if all targets already located in image
if len(detected) == nl:
break
# Continue if predicted class not among image classes
if pcls.item() not in tcls:
continue continue
# Best iou, index between pred and targets # Best iou, index between pred and targets
@ -122,16 +128,11 @@ def test(
# If iou > threshold and class is correct mark as correct # If iou > threshold and class is correct mark as correct
if iou > iou_thres and bi not in detected: if iou > iou_thres and bi not in detected:
correct.append(1) correct[i] = 1
detected.append(bi) detected.append(bi)
else:
correct.append(0)
else:
# If no labels add number of detections as incorrect
correct.extend([0] * len(pred))
# Append Statistics (correct, conf, pcls, tcls) # Append statistics (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu())) stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
# Compute statistics # Compute statistics
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))] stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
@ -177,7 +178,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(prog='test.py') parser = argparse.ArgumentParser(prog='test.py')
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch') parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path') parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
parser.add_argument('--data-cfg', type=str, default='data/coco.data', help='coco.data file path') parser.add_argument('--data-cfg', type=str, default='data/coco_10img.data', help='coco.data file path')
parser.add_argument('--weights', type=str, default='weights/yolov3-spp.weights', help='path to weights file') parser.add_argument('--weights', type=str, default='weights/yolov3-spp.weights', help='path to weights file')
parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected') parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')