// car-detection-bayes/our_scripts/generate-confusion-matrix.js
const fs = require('fs')
const path = require('path')
// Positional command-line arguments:
//   <detectionsDir> <labelsDir> <namesFile> [maxDistance] [width] [height]
const [detectionsDir, labelsDir, namesFile, ...optionalArgs] = process.argv.slice(2)
// Optional arguments fall back to their defaults when missing or empty.
// Values supplied on the command line are kept as strings.
const maxDistance = optionalArgs[0] || 0.1
const width = optionalArgs[1] || 1920
const height = optionalArgs[2] || 1080
/*
console.log("DETECTIONS DIRECTORY:", detectionsDir)
console.log("LABELS DIRECTORY:", labelsDir)
console.log("NAMES FILE:", namesFile)
console.log("WIDTH", width)
console.log("HEIGHT", height)
//*/
/**
 * Parses detection text with one box per line, given as six
 * whitespace-separated numbers: two corner points (x1 y1 x2 y2), a class id
 * and a sixth value (presumably the detection confidence - confirm against
 * the producer of these files).
 *
 * Corner coordinates are converted to a centre point and a size, each divided
 * by the module-level `width` / `height`.
 *
 * Blank or whitespace-only lines (including the "\r" left over from CRLF line
 * endings) are skipped, and any run of spaces or tabs separates fields.
 *
 * @param {string} detectionData raw file contents
 * @returns {{x: number, y: number, w: number, h: number, c: number, p: number}[]}
 */
function parseDetections(detectionData) {
  return detectionData
    .split('\n')
    .map(line => line.trim())
    .filter(line => line.length > 0)
    .map(line => line.split(/\s+/).map(Number))
    .map(([x1, y1, x2, y2, c, p]) => ({
      x: (x1 + x2) / (2 * width),
      y: (y1 + y2) / (2 * height),
      w: (x2 - x1) / width,
      h: (y2 - y1) / height,
      c,
      p
    }))
}
/**
 * Parses label text with one box per line, given as five
 * whitespace-separated numbers: `class x y w h`. The values are taken over
 * unchanged (presumably already normalised centre/size coordinates - confirm
 * against the label files) and every box gets `p` fixed to 1.
 *
 * Blank or whitespace-only lines (including the "\r" left over from CRLF line
 * endings) are skipped, and any run of spaces or tabs separates fields.
 *
 * @param {string} labelData raw file contents
 * @returns {{x: number, y: number, w: number, h: number, c: number, p: number}[]}
 */
function parseLabels(labelData) {
  return labelData
    .split('\n')
    .map(line => line.trim())
    .filter(line => line.length > 0)
    .map(line => line.split(/\s+/).map(Number))
    .map(([c, x, y, w, h]) => ({ x, y, w, h, c, p: 1 }))
}
/**
 * Finds the box whose (x, y) is closest to `position` by Euclidean distance.
 *
 * @param {{x: number, y: number}} position reference point
 * @param {{x: number, y: number}[]} boxes candidate boxes
 * @returns {object} a copy of the nearest box with an added `d` property
 *   holding the distance; on ties the earliest box wins. When `boxes` is
 *   empty, `{ d: Infinity }` is returned so that callers comparing `d`
 *   against a threshold treat it as "nothing nearby" instead of crashing.
 */
function findNearest(position, boxes) {
  if (boxes.length === 0) return { d: Infinity }

  const distanceTo = box => {
    const dx = position.x - box.x
    const dy = position.y - box.y
    return Math.sqrt(dx * dx + dy * dy)
  }

  let bestBox = { ...boxes[0], d: distanceTo(boxes[0]) }
  for (let i = 1; i < boxes.length; i++) {
    const distance = distanceTo(boxes[i])
    // Strict comparison keeps the first of several equally distant boxes
    if (distance < bestBox.d) {
      bestBox = { ...boxes[i], d: distance }
    }
  }
  return bestBox
}
/**
 * Builds a per-image confusion table from labels and detections.
 *
 * The result is a nested object `results[labelClass][detectedClass] = count`.
 * The key 'n' stands for "no counterpart":
 *  - `results[labelClass]['n']` counts labels whose nearest detection is
 *    farther away than `maxDistance` (or that have no detection at all);
 *  - `results['n'][detectedClass]` counts detections whose nearest label is
 *    farther away than `maxDistance` (or that have no label at all).
 *
 * @param {{x: number, y: number, c: number}[]} labels
 * @param {{x: number, y: number, c: number}[]} detections
 * @returns {Object<string, Object<string, number>>}
 */
function compare(labels, detections) {
  const results = {}
  const bump = (row, column) => {
    if (!results[row]) results[row] = {}
    results[row][column] = (results[row][column] || 0) + 1
  }

  for (const label of labels) {
    // An empty detection list means every label is unmatched
    const detection = detections.length > 0 ? findNearest(label, detections) : null
    if (detection === null || detection.d > maxDistance) {
      bump(label.c, 'n')
    } else {
      bump(label.c, detection.c)
    }
  }

  for (const detection of detections) {
    // Matched detections were already counted from the label side above;
    // only detections without a nearby label are recorded here.
    const label = labels.length > 0 ? findNearest(detection, labels) : null
    if (label === null || label.d > maxDistance) {
      bump('n', detection.c)
    }
  }

  return results
}
/**
 * Loads one detection file and its label file and compares them.
 *
 * The label file is looked up in `labelsDir` under the detection file's name
 * truncated at its first dot, with '.txt' appended.
 *
 * @param {string} txt detection file name, relative to `detectionsDir`
 * @returns {Promise<{basename: string, result: object}>} the truncated file
 *   name and the table produced by `compare`
 */
async function compareLabelsAndResults(txt) {
  const detectionPath = path.resolve(detectionsDir, txt)
  const basename = path.basename(txt.split('.')[0])
  const labelPath = path.resolve(labelsDir, basename + '.txt')
  const [detectionData, labelData] = await Promise.all([
    fs.promises.readFile(detectionPath, 'utf8'),
    fs.promises.readFile(labelPath, 'utf8')
  ])
  // Detection files are read with the label parser, i.e. they are treated as
  // having the same `class x y w h` layout as the label files.
  const detections = parseLabels(detectionData)
  const labels = parseLabels(labelData)
  return {
    basename,
    result: compare(labels, detections)
  }
}
/**
 * Entry point: compares every '.txt' file in `detectionsDir` with its label
 * file, accumulates a confusion matrix over all images and prints summary
 * statistics followed by the matrix (as counts and as row fractions) in
 * tab-separated form.
 *
 * For every image that has at least one off-diagonal entry, the image
 * '<basename>.jpg' is copied into '<detectionsDir>/errors' and a
 * '<basename>.tsv' listing those entries is written next to it.
 *
 * @returns {Promise<void>}
 */
async function main() {
  const add = (a, b) => a + b

  // Class names are indexed by class id (line number in the names file);
  // the pseudo-classes 'n' (no counterpart) and 'sum' get display names too.
  const names = (await fs.promises.readFile(namesFile, 'utf8')).split('\n').map(t => t.trim())
  names.n = '?'
  names.sum = 'sum'

  const files = await fs.promises.readdir(detectionsDir)
  const txts = files.filter(p => path.extname(p) === '.txt')
  const compareResults = await Promise.all(txts.map(txt => compareLabelsAndResults(txt)))

  // `recursive: true` makes an already existing directory a non-error; any
  // other failure is reported but does not prevent the summary from printing.
  const errorsDir = path.resolve(detectionsDir, 'errors')
  try {
    await fs.promises.mkdir(errorsDir, { recursive: true })
  } catch (err) {
    console.warn(`could not create ${errorsDir}: ${err.message}`)
  }

  const summary = {}
  const copyPromises = []
  for (const result of compareResults) {
    const errors = []
    for (const c in result.result) {
      if (!summary[c]) summary[c] = {}
      for (const r in result.result[c]) {
        summary[c][r] = (summary[c][r] || 0) + result.result[c][r]
        if (c !== r) errors.push([c, r, result.result[c][r]])
      }
    }
    if (errors.length > 0) {
      copyPromises.push(fs.promises.copyFile(
        path.resolve(detectionsDir, result.basename + '.jpg'),
        path.resolve(errorsDir, result.basename + '.jpg')
      ))
      copyPromises.push(fs.promises.writeFile(
        path.resolve(errorsDir, result.basename + '.tsv'),
        errors.map(([c1, c2, cnt]) => [names[c1], names[c2], cnt].join('\t')).join('\n'),
        'utf8'
      ))
    }
  }

  // Order class ids numerically (a plain sort would put "10" before "2");
  // non-numeric keys fall back to string order. 'n' always comes last.
  const byClassId = (a, b) => (Number(a) - Number(b)) || (a < b ? -1 : a > b ? 1 : 0)
  const rows = Object.keys(summary).filter(k => k !== 'n').sort(byClassId).concat(['n'])

  summary.sum = {}
  for (const row of rows) {
    if (!summary[row]) summary[row] = {}
    const rowSum = rows.map(r => summary[row][r] || 0).reduce(add, 0)
    const columnSum = rows.map(r => (summary[r] && summary[r][row]) || 0).reduce(add, 0)
    summary[row].sum = rowSum
    summary.sum[row] = columnSum
  }

  const summaryRows = [...rows, 'sum']
  const tsvRows = []

  tsvRows.push('Count:')
  tsvRows.push([' ', ...summaryRows.map(n => names[n])].join('\t'))
  for (const row of summaryRows) {
    const summaryPart = summary[row] || {}
    tsvRows.push([names[row], ...summaryRows.map(r => summaryPart[r])].join('\t'))
  }

  // The fraction table omits the 'sum' row and column; each cell is divided
  // by its row total. Missing cells stay empty.
  tsvRows.push('Fraction:')
  tsvRows.push([' ', ...rows.map(n => names[n])].join('\t'))
  for (const row of rows) {
    const summaryPart = summary[row] || {}
    const sum = summaryPart.sum
    tsvRows.push([names[row], ...rows.map(r => summaryPart[r] && (summaryPart[r] / sum).toFixed(2))].join('\t'))
  }

  // NOTE(review): the matrix is indexed summary[labelClass][detectedClass]
  // (see compare), so summary.sum[r] is a column total and summary[r].n
  // counts labels without a nearby detection. The names and captions below
  // look swapped relative to that orientation - confirm before relying on
  // them. The computation itself is left unchanged.
  const classes = rows.slice(0, -1)
  const allLabeled = classes.map(r => summary.sum[r]).reduce(add, 0)
  const allDetected = classes.map(r => summary[r].sum).reduce(add, 0)
  const falseNegatives = classes.map(r => summary.n[r] || 0).reduce(add, 0)
  const falsePositives = classes.map(r => summary[r].n || 0).reduce(add, 0)
  const right = classes.map(r => summary[r][r] || 0).reduce(add, 0)
  const mistakes = classes
    .map(a => classes.map(b => (a !== b && summary[a][b]) || 0).reduce(add, 0))
    .reduce(add, 0)

  console.log(`right:\t${right}\t${(right/allLabeled).toFixed(3)}`)
  console.log(`false positives:\t${falsePositives}\t${(falsePositives/allLabeled).toFixed(3)}`)
  console.log(`false negatives:\t${falseNegatives}\t${(falseNegatives/allLabeled).toFixed(3)}`)
  console.log(`mistakes:\t${mistakes}\t${(mistakes/allLabeled).toFixed(3)}`)
  console.log(`labeled:\t${allLabeled}`)
  console.log(`detected:\t${allDetected}`)

  console.log(tsvRows.join('\n'))

  // Wait for the error-image copies and per-image TSV files to finish
  await Promise.all(copyPromises)
}
// Run the script; report any failure and signal it through the exit status
// instead of leaving the returned promise unhandled.
main().catch(err => {
  console.error(err)
  process.exitCode = 1
})