updates
This commit is contained in:
parent
6b222df35d
commit
2244c72a1b
|
@ -187,6 +187,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
|
|
||||||
# Preload labels (required for weighted CE training)
|
# Preload labels (required for weighted CE training)
|
||||||
self.imgs = [None] * n
|
self.imgs = [None] * n
|
||||||
|
self.labels = [None] * n
|
||||||
|
preload_labels = False
|
||||||
|
if preload_labels:
|
||||||
self.labels = [np.zeros((0, 5))] * n
|
self.labels = [np.zeros((0, 5))] * n
|
||||||
iter = tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files
|
iter = tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files
|
||||||
extract_bounding_boxes = False
|
extract_bounding_boxes = False
|
||||||
|
@ -206,7 +209,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
img = cv2.imread(str(p))
|
img = cv2.imread(str(p))
|
||||||
h, w, _ = img.shape
|
h, w, _ = img.shape
|
||||||
for j, x in enumerate(l):
|
for j, x in enumerate(l):
|
||||||
f = '%s%sclassification%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
|
f = '%s%sclassification%s%g_%g_%s' % (
|
||||||
|
p.parent.parent, os.sep, os.sep, x[0], j, p.name)
|
||||||
if not os.path.exists(Path(f).parent):
|
if not os.path.exists(Path(f).parent):
|
||||||
os.makedirs(Path(f).parent) # make new output folder
|
os.makedirs(Path(f).parent) # make new output folder
|
||||||
box = xywh2xyxy(x[1:].reshape(-1, 4)).ravel()
|
box = xywh2xyxy(x[1:].reshape(-1, 4)).ravel()
|
||||||
|
@ -215,7 +219,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
int(box[0] * w):int(box[2] * w)])
|
int(box[0] * w):int(box[2] * w)])
|
||||||
if not result:
|
if not result:
|
||||||
print('stop')
|
print('stop')
|
||||||
|
|
||||||
except:
|
except:
|
||||||
pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file
|
pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file
|
||||||
assert len(np.concatenate(self.labels, 0)) > 0, 'No labels found. Incorrect label paths provided.'
|
assert len(np.concatenate(self.labels, 0)) > 0, 'No labels found. Incorrect label paths provided.'
|
||||||
|
@ -274,9 +277,12 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
# Load labels
|
# Load labels
|
||||||
labels = []
|
labels = []
|
||||||
if os.path.isfile(label_path):
|
if os.path.isfile(label_path):
|
||||||
# with open(label_path, 'r') as f:
|
|
||||||
# x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
|
|
||||||
x = self.labels[index]
|
x = self.labels[index]
|
||||||
|
if x is None: # labels not preloaded
|
||||||
|
with open(label_path, 'r') as f:
|
||||||
|
x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
|
||||||
|
self.labels[index] = x # save for next time
|
||||||
|
|
||||||
if x.size > 0:
|
if x.size > 0:
|
||||||
# Normalized xywh to pixel xyxy format
|
# Normalized xywh to pixel xyxy format
|
||||||
labels = x.copy()
|
labels = x.copy()
|
||||||
|
|
Loading…
Reference in New Issue