diff --git a/train.py b/train.py index c3763fdc..57ce04fd 100644 --- a/train.py +++ b/train.py @@ -185,6 +185,7 @@ def train(): hyp=hyp, # augmentation hyperparameters rect=opt.rect, # rectangular training image_weights=opt.img_weights, + cache_labels=True if epochs > 10 else False, cache_images=False if opt.prebias else opt.cache_images) # Dataloader diff --git a/utils/datasets.py b/utils/datasets.py index 6c94b375..0116d38e 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -246,7 +246,7 @@ class LoadStreams: # multiple IP or RTSP cameras class LoadImagesAndLabels(Dataset): # for training/testing def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False, - cache_images=False): + cache_labels=False, cache_images=False): path = str(Path(path)) # os-agnostic with open(path, 'r') as f: self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic @@ -305,7 +305,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing # Preload labels (required for weighted CE training) self.imgs = [None] * n self.labels = [None] * n - if augment or image_weights: # cache labels for faster training + if cache_labels or image_weights: # cache labels for faster training self.labels = [np.zeros((0, 5))] * n extract_bounding_boxes = False pbar = tqdm(self.label_files, desc='Reading labels')