car-detection-bayes/our_scripts/generate_txt_from_xml.py

152 lines
5.1 KiB
Python

# -*- coding: utf-8 -*-
import sys
import os
import re
import xml.etree.ElementTree as ET
from glob import glob
from os.path import join
from pathlib import Path
# Command-line argument 1: folder that is searched for the '*.xml' annotation
# files. The value is string-concatenated with '*.xml' further down, so it is
# expected to end with a path separator.
annotations = sys.argv[1]
# Command-line argument 2: folder into which the generated '.txt' files are
# written (one per accepted XML file, same base name).
modified_annotations = sys.argv[2]
def convert(size, box):
    """Turn a corner-style box into a normalised centre/size box.

    Args:
        size: sequence ``(width, height)`` of the image.
        box: sequence ``(xmin, xmax, ymin, ymax)`` of the bounding box.

    Returns:
        Tuple ``(x, y, w, h)``: box centre and box extent, each divided by the
        matching image dimension and rounded to 4 decimal places.
    """
    inv_width = 1. / (size[0])
    inv_height = 1. / (size[1])

    # The centre is shifted by one unit before scaling.
    # NOTE(review): presumably compensates for 1-based pixel coordinates in
    # the annotations - confirm against the annotation tool.
    centre_x = (box[0] + box[1]) / 2.0 - 1
    centre_y = (box[2] + box[3]) / 2.0 - 1
    extent_x = box[1] - box[0]
    extent_y = box[3] - box[2]

    return (
        round(centre_x * inv_width, 4),
        round(centre_y * inv_height, 4),
        round(extent_x * inv_width, 4),
        round(extent_y * inv_height, 4),
    )
# Id returned for labels that are deliberately left out of the output.
_EXCLUDED_CLASS_ID = -1

# Label text exactly as it appears in the XML <name> element -> output class id.
# The label strings must stay byte-identical to the annotation data.
_CLASS_NAME_TO_ID = {
    '1. rower': 0,
    '2. motocykl': 1,
    '3. osobowy': 2,
    '4. osobowy pickup': 3,
    '5. osobowy dostawczy': 4,
    '6. osobowy van 7-9': 5,
    '7. dostawczy blaszak': 6,
    '8. dostawczy zabudowany': 7,
    '9. dostawczy pickup (w tym pomoc drog.)': 8,
    '10. dostawczy VAN (osobowy)': 9,
    '11. autobus mały 10-24': _EXCLUDED_CLASS_ID,
    '12. autobus miejski': 10,
    '13. autobus turystyczny i inny': _EXCLUDED_CLASS_ID,
    '14. ciężarowy pow. 3,5t zabudowany': 11,
    '15. ciężarowy pow. 3,5t otwarty (w tym duży holownik)': 12,
    '16. ciężarowy pow. 3,5t inny (wanna, gruszka, dźwig itp.)': 13,
    '17. ciężarowy z widoczną przyczepą': _EXCLUDED_CLASS_ID,
    '18. ciągnik siodłowy z widoczną naczepą': 14,
    '19. ciągnik siodłowy bez naczepy': 15,
    '20. camper': _EXCLUDED_CLASS_ID,
    '22. ciągnik roliczy, koparka, spychacz': _EXCLUDED_CLASS_ID,
    '23. inne pojazdy silnikowe': _EXCLUDED_CLASS_ID,
    '24. przyczepa': 16,
    '25. BUS-karetka/policja': 17,
}


def map_class_name_to_id(class_name, xml_document, class_distribution):
    """Translate an annotation label into its numeric class id.

    Args:
        class_name: label text read from the annotation.
        xml_document: identifier of the annotation being processed; only used
            in the error raised for an unknown label.
        class_distribution: mutable list of per-class counters, indexed by
            class id. The counter of the returned class is incremented in
            place.

    Returns:
        The class id (0-17) for a kept label, or -1 for a label that is
        excluded from the output. Excluded labels do not touch
        ``class_distribution``.

    Raises:
        Exception: if ``class_name`` is not a known label.
    """
    if class_name not in _CLASS_NAME_TO_ID:
        raise Exception('Unknown Class ', xml_document, class_name)

    class_id = _CLASS_NAME_TO_ID[class_name]
    # Only kept classes are counted. The previous if/elif chain bumped the
    # counter of class 15 for '20. camper' although that label is excluded.
    if class_id != _EXCLUDED_CLASS_ID:
        class_distribution[class_id] += 1
    return class_id
def generate_txt_from_xml():
    """Convert every XML annotation file into a txt label file.

    Reads each ``*.xml`` file in the module-level ``annotations`` folder. A
    file in which any object maps to class id -1 is discarded as a whole.
    For every other file a txt file with the same base name is written to the
    module-level ``modified_annotations`` folder, holding one line per object:
    the class id followed by the four values returned by ``convert``,
    separated by single spaces.

    When all files are processed, the per-class object counts of the files
    that were actually written are printed.

    Raises:
        Exception: propagated from ``map_class_name_to_id`` when an object
            carries an unknown label.
    """
    class_distribution = [0] * 18
    # join() rather than string concatenation, so the folder argument works
    # with or without a trailing path separator.
    filepaths = glob(join(annotations, '*.xml'))
    for filepath in filepaths:
        txtpath = join(modified_annotations, re.sub(r"\.xml$", ".txt", os.path.basename(filepath)))
        with open(filepath, mode='r', encoding='utf-8') as in_file:
            root = ET.parse(in_file).getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)

        # Map each object exactly once, counting into a per-file scratch list.
        # Counting straight into class_distribution (and mapping twice, as
        # before) doubled the counts of written files and also counted objects
        # of files that end up discarded.
        file_distribution = [0] * len(class_distribution)
        labelled_objects = [
            (map_class_name_to_id(obj.find('name').text, filepath, file_distribution), obj)
            for obj in root.iter('object')
        ]
        if any(cls_id == -1 for cls_id, _ in labelled_objects):
            print('File discarded.')
            continue

        for class_index, count in enumerate(file_distribution):
            class_distribution[class_index] += count

        lines = []
        for cls_id, obj in labelled_objects:
            xmlbox = obj.find('bndbox')
            # convert() expects the box as (xmin, xmax, ymin, ymax).
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            lines.append(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

        # Opening with mode 'w' creates the file, so no separate touch() is
        # needed; the context manager guarantees the data is flushed and the
        # handle closed.
        with open(txtpath, mode='w', encoding='utf-8') as out_file:
            out_file.writelines(lines)
    print(class_distribution)
# Run the conversion at module load; there is no __main__ guard, so this also
# executes when the file is imported.
generate_txt_from_xml()