updates
This commit is contained in:
parent
bfc77ec88f
commit
d65d64bb7e
41
test.py
41
test.py
|
@ -81,14 +81,13 @@ def test(
|
||||||
# Statistics per image
|
# Statistics per image
|
||||||
for si, pred in enumerate(output):
|
for si, pred in enumerate(output):
|
||||||
labels = targets[targets[:, 0] == si, 1:]
|
labels = targets[targets[:, 0] == si, 1:]
|
||||||
correct, detected = [], []
|
nl = len(labels)
|
||||||
tcls = torch.Tensor()
|
tcls = labels[:, 0].tolist() if nl else [] # target class
|
||||||
seen += 1
|
seen += 1
|
||||||
|
|
||||||
if pred is None:
|
if pred is None:
|
||||||
if len(labels):
|
if nl:
|
||||||
tcls = labels[:, 0].cpu() # target classes
|
stats.append(([], torch.Tensor(), torch.Tensor(), tcls))
|
||||||
stats.append((correct, torch.Tensor(), torch.Tensor(), tcls))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Append to pycocotools JSON dictionary
|
# Append to pycocotools JSON dictionary
|
||||||
|
@ -107,14 +106,21 @@ def test(
|
||||||
'score': float(d[4])
|
'score': float(d[4])
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(labels):
|
# Assign all predictions as incorrect
|
||||||
# Extract target boxes as (x1, y1, x2, y2)
|
correct = [0] * len(pred)
|
||||||
|
if nl:
|
||||||
|
detected = []
|
||||||
tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes
|
tbox = xywh2xyxy(labels[:, 1:5]) * img_size # target boxes
|
||||||
tcls = labels[:, 0] # target classes
|
|
||||||
|
|
||||||
for *pbox, pconf, pcls_conf, pcls in pred:
|
# Search for correct predictions
|
||||||
if pcls not in tcls:
|
for i, (*pbox, pconf, pcls_conf, pcls) in enumerate(pred):
|
||||||
correct.append(0)
|
|
||||||
|
# Break if all targets already located in image
|
||||||
|
if len(detected) == nl:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Continue if predicted class not among image classes
|
||||||
|
if pcls.item() not in tcls:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Best iou, index between pred and targets
|
# Best iou, index between pred and targets
|
||||||
|
@ -122,16 +128,11 @@ def test(
|
||||||
|
|
||||||
# If iou > threshold and class is correct mark as correct
|
# If iou > threshold and class is correct mark as correct
|
||||||
if iou > iou_thres and bi not in detected:
|
if iou > iou_thres and bi not in detected:
|
||||||
correct.append(1)
|
correct[i] = 1
|
||||||
detected.append(bi)
|
detected.append(bi)
|
||||||
else:
|
|
||||||
correct.append(0)
|
|
||||||
else:
|
|
||||||
# If no labels add number of detections as incorrect
|
|
||||||
correct.extend([0] * len(pred))
|
|
||||||
|
|
||||||
# Append Statistics (correct, conf, pcls, tcls)
|
# Append statistics (correct, conf, pcls, tcls)
|
||||||
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls.cpu()))
|
stats.append((correct, pred[:, 4].cpu(), pred[:, 6].cpu(), tcls))
|
||||||
|
|
||||||
# Compute statistics
|
# Compute statistics
|
||||||
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
|
stats_np = [np.concatenate(x, 0) for x in list(zip(*stats))]
|
||||||
|
@ -177,7 +178,7 @@ if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(prog='test.py')
|
parser = argparse.ArgumentParser(prog='test.py')
|
||||||
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
|
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
|
||||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
|
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
|
||||||
parser.add_argument('--data-cfg', type=str, default='data/coco.data', help='coco.data file path')
|
parser.add_argument('--data-cfg', type=str, default='data/coco_10img.data', help='coco.data file path')
|
||||||
parser.add_argument('--weights', type=str, default='weights/yolov3-spp.weights', help='path to weights file')
|
parser.add_argument('--weights', type=str, default='weights/yolov3-spp.weights', help='path to weights file')
|
||||||
parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
|
parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
|
||||||
parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
|
parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
|
||||||
|
|
Loading…
Reference in New Issue