updates
This commit is contained in:
parent
056976b4fc
commit
f43170817c
4
train.py
4
train.py
|
@ -187,7 +187,8 @@ def train(cfg,
|
|||
augment=True,
|
||||
hyp=hyp, # augmentation hyperparameters
|
||||
rect=opt.rect, # rectangular training
|
||||
image_weights=opt.img_weights)
|
||||
image_weights=opt.img_weights,
|
||||
cache_images=opt.cache_images)
|
||||
|
||||
# Dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset,
|
||||
|
@ -352,6 +353,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
||||
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
|
||||
parser.add_argument('--img-weights', action='store_true', help='select training images by weight')
|
||||
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
|
||||
opt = parser.parse_args()
|
||||
print(opt)
|
||||
|
||||
|
|
|
@ -156,7 +156,7 @@ class LoadWebcam: # for inference
|
|||
|
||||
|
||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False):
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False, cache_images=False):
|
||||
path = str(Path(path)) # os-agnostic
|
||||
with open(path, 'r') as f:
|
||||
self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
|
||||
|
@ -254,7 +254,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
assert nf > 0, 'No labels found. Recommend correcting image and label paths.'
|
||||
|
||||
# Cache images into memory for faster training (~5GB)
|
||||
cache_images = False
|
||||
if cache_images and augment: # if training
|
||||
for i in tqdm(range(min(len(self.img_files), 10000)), desc='Reading images'): # max 10k images
|
||||
img_path = self.img_files[i]
|
||||
|
|
Loading…
Reference in New Issue