This commit is contained in:
Glenn Jocher 2019-09-16 19:53:44 +02:00
parent efe86b0c4c
commit 78e9bf60d2
2 changed files with 3 additions and 2 deletions

View File

@ -185,6 +185,7 @@ def train():
hyp=hyp, # augmentation hyperparameters
rect=opt.rect, # rectangular training
image_weights=opt.img_weights,
cache_labels=True if epochs > 10 else False,
cache_images=False if opt.prebias else opt.cache_images)
# Dataloader

View File

@ -246,7 +246,7 @@ class LoadStreams: # multiple IP or RTSP cameras
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False,
cache_images=False):
cache_labels=False, cache_images=False):
path = str(Path(path)) # os-agnostic
with open(path, 'r') as f:
self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
@ -305,7 +305,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Preload labels (required for weighted CE training)
self.imgs = [None] * n
self.labels = [None] * n
if augment or image_weights: # cache labels for faster training
if cache_labels or image_weights: # cache labels for faster training
self.labels = [np.zeros((0, 5))] * n
extract_bounding_boxes = False
pbar = tqdm(self.label_files, desc='Reading labels')