mAP recorded during training
This commit is contained in:
parent
9dbc3ec1c4
commit
45c5567723
180
test.py
180
test.py
|
@ -11,7 +11,7 @@ parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help
|
|||
parser.add_argument('-weights_path', type=str, default='weights/yolov3.pt', help='path to weights file')
|
||||
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
|
||||
parser.add_argument('-iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
|
||||
parser.add_argument('-conf_thres', type=float, default=0.5, help='object confidence threshold')
|
||||
parser.add_argument('-conf_thres', type=float, default=0.3, help='object confidence threshold')
|
||||
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
|
||||
parser.add_argument('-n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
|
||||
parser.add_argument('-img_size', type=int, default=416, help='size of each image dimension')
|
||||
|
@ -21,112 +21,118 @@ print(opt)
|
|||
cuda = torch.cuda.is_available()
|
||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
||||
|
||||
# Configure run
|
||||
data_config = parse_data_config(opt.data_config_path)
|
||||
num_classes = int(data_config['classes'])
|
||||
if platform == 'darwin': # MacOS (local)
|
||||
test_path = data_config['valid']
|
||||
else: # linux (cloud, i.e. gcp)
|
||||
test_path = '../coco/5k.part'
|
||||
|
||||
# Initiate model
|
||||
model = Darknet(opt.cfg, opt.img_size)
|
||||
def main(opt):
|
||||
# Configure run
|
||||
data_config = parse_data_config(opt.data_config_path)
|
||||
nC = int(data_config['classes']) # number of classes (80 for COCO)
|
||||
if platform == 'darwin': # MacOS (local)
|
||||
test_path = data_config['valid']
|
||||
else: # linux (cloud, i.e. gcp)
|
||||
test_path = '../coco/5k.part'
|
||||
|
||||
# Load weights
|
||||
if opt.weights_path.endswith('.weights'): # darknet format
|
||||
load_weights(model, opt.weights_path)
|
||||
elif opt.weights_path.endswith('.pt'): # pytorch format
|
||||
checkpoint = torch.load(opt.weights_path, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
del checkpoint
|
||||
# Initiate model
|
||||
model = Darknet(opt.cfg, opt.img_size)
|
||||
|
||||
model.to(device).eval()
|
||||
# Load weights
|
||||
if opt.weights_path.endswith('.weights'): # darknet format
|
||||
load_weights(model, opt.weights_path)
|
||||
elif opt.weights_path.endswith('.pt'): # pytorch format
|
||||
checkpoint = torch.load(opt.weights_path, map_location='cpu')
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
del checkpoint
|
||||
|
||||
# Get dataloader
|
||||
# dataset = load_images_with_labels(test_path)
|
||||
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.n_cpu)
|
||||
dataloader = load_images_and_labels(test_path, batch_size=opt.batch_size, img_size=opt.img_size)
|
||||
model.to(device).eval()
|
||||
|
||||
print('Compute mAP...')
|
||||
# Get dataloader
|
||||
# dataset = load_images_with_labels(test_path)
|
||||
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.n_cpu)
|
||||
dataloader = load_images_and_labels(test_path, batch_size=opt.batch_size, img_size=opt.img_size)
|
||||
|
||||
nC = 80 # number of classes
|
||||
correct = 0
|
||||
targets = None
|
||||
outputs, mAPs, TP, confidence, pred_class, target_class = [], [], [], [], [], []
|
||||
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
||||
for batch_i, (imgs, targets) in enumerate(dataloader):
|
||||
imgs = imgs.to(device)
|
||||
print('Compute mAP...')
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(imgs)
|
||||
output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)
|
||||
mAP = 0
|
||||
outputs, mAPs, TP, confidence, pred_class, target_class = [], [], [], [], [], []
|
||||
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
|
||||
for batch_i, (imgs, targets) in enumerate(dataloader):
|
||||
imgs = imgs.to(device)
|
||||
|
||||
# Compute average precision for each sample
|
||||
for sample_i in range(len(targets)):
|
||||
correct = []
|
||||
with torch.no_grad():
|
||||
output = model(imgs)
|
||||
output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)
|
||||
|
||||
# Get labels for sample where width is not zero (dummies)
|
||||
annotations = targets[sample_i]
|
||||
# Extract detections
|
||||
detections = output[sample_i]
|
||||
# Compute average precision for each sample
|
||||
for sample_i in range(len(targets)):
|
||||
correct = []
|
||||
|
||||
if detections is None:
|
||||
# If there are no detections but there are annotations mask as zero AP
|
||||
if annotations.size(0) != 0:
|
||||
# Get labels for sample where width is not zero (dummies)
|
||||
annotations = targets[sample_i]
|
||||
# Extract detections
|
||||
detections = output[sample_i]
|
||||
|
||||
if detections is None:
|
||||
# If there are no detections but there are annotations mask as zero AP
|
||||
if annotations.size(0) != 0:
|
||||
mAPs.append(0)
|
||||
continue
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
detections = detections[np.argsort(-detections[:, 4])]
|
||||
|
||||
# If no annotations add number of detections as incorrect
|
||||
if annotations.size(0) == 0:
|
||||
# correct.extend([0 for _ in range(len(detections))])
|
||||
mAPs.append(0)
|
||||
continue
|
||||
continue
|
||||
else:
|
||||
target_cls = annotations[:, 0]
|
||||
|
||||
# Get detections sorted by decreasing confidence scores
|
||||
detections = detections[np.argsort(-detections[:, 4])]
|
||||
# Extract target boxes as (x1, y1, x2, y2)
|
||||
target_boxes = xywh2xyxy(annotations[:, 1:5])
|
||||
target_boxes *= opt.img_size
|
||||
|
||||
# If no annotations add number of detections as incorrect
|
||||
if annotations.size(0) == 0:
|
||||
target_cls = []
|
||||
# correct.extend([0 for _ in range(len(detections))])
|
||||
mAPs.append(0)
|
||||
continue
|
||||
else:
|
||||
target_cls = annotations[:, 0]
|
||||
detected = []
|
||||
for *pred_bbox, conf, obj_conf, obj_pred in detections:
|
||||
|
||||
# Extract target boxes as (x1, y1, x2, y2)
|
||||
target_boxes = xywh2xyxy(annotations[:, 1:5])
|
||||
target_boxes *= opt.img_size
|
||||
pred_bbox = torch.FloatTensor(pred_bbox).view(1, -1)
|
||||
# Compute iou with target boxes
|
||||
iou = bbox_iou(pred_bbox, target_boxes)
|
||||
# Extract index of largest overlap
|
||||
best_i = np.argmax(iou)
|
||||
# If overlap exceeds threshold and classification is correct mark as correct
|
||||
if iou[best_i] > opt.iou_thres and obj_pred == annotations[best_i, 0] and best_i not in detected:
|
||||
correct.append(1)
|
||||
detected.append(best_i)
|
||||
else:
|
||||
correct.append(0)
|
||||
|
||||
detected = []
|
||||
for *pred_bbox, conf, obj_conf, obj_pred in detections:
|
||||
# Compute Average Precision (AP) per class
|
||||
AP, AP_class = ap_per_class(tp=correct, conf=detections[:, 4], pred_cls=detections[:, 6],
|
||||
target_cls=target_cls)
|
||||
|
||||
pred_bbox = torch.FloatTensor(pred_bbox).view(1, -1)
|
||||
# Compute iou with target boxes
|
||||
iou = bbox_iou(pred_bbox, target_boxes)
|
||||
# Extract index of largest overlap
|
||||
best_i = np.argmax(iou)
|
||||
# If overlap exceeds threshold and classification is correct mark as correct
|
||||
if iou[best_i] > opt.iou_thres and obj_pred == annotations[best_i, 0] and best_i not in detected:
|
||||
correct.append(1)
|
||||
detected.append(best_i)
|
||||
else:
|
||||
correct.append(0)
|
||||
# Accumulate AP per class
|
||||
AP_accum_count += np.bincount(AP_class, minlength=nC)
|
||||
AP_accum += np.bincount(AP_class, minlength=nC, weights=AP)
|
||||
|
||||
# Compute Average Precision (AP) per class
|
||||
AP, AP_class = ap_per_class(tp=correct, conf=detections[:, 4], pred_cls=detections[:, 6], target_cls=target_cls)
|
||||
# Compute mean AP for this image
|
||||
mAP = AP.mean()
|
||||
|
||||
# Accumulate AP per class
|
||||
AP_accum_count += np.bincount(AP_class, minlength=nC)
|
||||
AP_accum += np.bincount(AP_class, minlength=nC, weights=AP)
|
||||
# Append image mAP to list
|
||||
mAPs.append(mAP)
|
||||
|
||||
# Compute mean AP for this image
|
||||
mAP = AP.mean()
|
||||
# Print image mAP and running mean mAP
|
||||
print(
|
||||
'+ Sample [%d/%d] AP: %.4f (%.4f)' % (len(mAPs), len(dataloader) * opt.batch_size, mAP, np.mean(mAPs)))
|
||||
|
||||
# Append image mAP to list
|
||||
mAPs.append(mAP)
|
||||
# Print mAP per class
|
||||
classes = load_classes(opt.class_path) # Extracts class labels from file
|
||||
for i, c in enumerate(classes):
|
||||
print('%15s: %-.4f' % (c, AP_accum[i] / AP_accum_count[i]))
|
||||
|
||||
# Print image mAP and running mean mAP
|
||||
print('+ Sample [%d/%d] AP: %.4f (%.4f)' % (len(mAPs), len(dataloader) * opt.batch_size, mAP, np.mean(mAPs)))
|
||||
# Print mAP
|
||||
print('Mean Average Precision: %.4f' % np.mean(mAPs))
|
||||
return mAP
|
||||
|
||||
# Print mAP per class
|
||||
classes = load_classes(opt.class_path) # Extracts class labels from file
|
||||
for i, c in enumerate(classes):
|
||||
print('%15s: %-.4f' % (c, AP_accum[i] / AP_accum_count[i]))
|
||||
|
||||
# Print mAP
|
||||
print('Mean Average Precision: %.4f' % np.mean(mAPs))
|
||||
if __name__ == '__main__':
|
||||
mAP = main(opt)
|
||||
|
|
19
train.py
19
train.py
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
import time
|
||||
import test
|
||||
|
||||
from models import *
|
||||
from utils.datasets import *
|
||||
|
@ -103,10 +104,10 @@ def main(opt):
|
|||
# scheduler.step()
|
||||
|
||||
# Update scheduler (manual) at 0, 54, 61 epochs to 1e-3, 1e-4, 1e-5
|
||||
if epoch < 50:
|
||||
lr = 1e-4
|
||||
else:
|
||||
if epoch > 50:
|
||||
lr = 1e-5
|
||||
else:
|
||||
lr = 1e-4
|
||||
for g in optimizer.param_groups:
|
||||
g['lr'] = lr
|
||||
|
||||
|
@ -160,10 +161,6 @@ def main(opt):
|
|||
t1 = time.time()
|
||||
print(s)
|
||||
|
||||
# Write epoch results
|
||||
with open('results.txt', 'a') as file:
|
||||
file.write(s + '\n')
|
||||
|
||||
# Update best loss
|
||||
loss_per_target = rloss['loss'] / rloss['nT']
|
||||
if loss_per_target < best_loss:
|
||||
|
@ -184,6 +181,14 @@ def main(opt):
|
|||
if (epoch > 0) & (epoch % 5 == 0):
|
||||
os.system('cp weights/latest.pt weights/backup' + str(epoch) + '.pt')
|
||||
|
||||
# Calculate mAP
|
||||
test.opt.weights_path = 'weights/latest.pt'
|
||||
mAP = test.main(test.opt)
|
||||
|
||||
# Write epoch results
|
||||
with open('results.txt', 'a') as file:
|
||||
file.write(s + '%11.3g' % mAP + '\n')
|
||||
|
||||
# Save final model
|
||||
dt = time.time() - t0
|
||||
print('Finished %g epochs in %.2fs (%.2fs/epoch)' % (epoch, dt, dt / (epoch + 1)))
|
||||
|
|
|
@ -11,7 +11,7 @@ gsutil cp gs://ultralytics/yolov3.pt yolov3/weights
|
|||
python3 detect.py
|
||||
|
||||
# Test
|
||||
python3 test.py -img_size 416 -weights_path weights/latest.pt -conf_thres 0.1
|
||||
python3 test.py -img_size 416 -weights_path weights/latest.pt -conf_thres 0.5
|
||||
|
||||
# Download and Test
|
||||
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3 && cd yolov3
|
||||
|
|
|
@ -435,7 +435,7 @@ def plot_results():
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure(figsize=(16, 8))
|
||||
s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall']
|
||||
s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
|
||||
for f in ('results.txt',):
|
||||
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 9, 10]).T
|
||||
for i in range(9):
|
||||
|
|
Loading…
Reference in New Issue