updates
This commit is contained in:
parent
efe86b0c4c
commit
78e9bf60d2
1
train.py
1
train.py
|
@ -185,6 +185,7 @@ def train():
|
||||||
hyp=hyp, # augmentation hyperparameters
|
hyp=hyp, # augmentation hyperparameters
|
||||||
rect=opt.rect, # rectangular training
|
rect=opt.rect, # rectangular training
|
||||||
image_weights=opt.img_weights,
|
image_weights=opt.img_weights,
|
||||||
|
cache_labels=True if epochs > 10 else False,
|
||||||
cache_images=False if opt.prebias else opt.cache_images)
|
cache_images=False if opt.prebias else opt.cache_images)
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
|
|
|
@ -246,7 +246,7 @@ class LoadStreams: # multiple IP or RTSP cameras
|
||||||
|
|
||||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False,
|
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False,
|
||||||
cache_images=False):
|
cache_labels=False, cache_images=False):
|
||||||
path = str(Path(path)) # os-agnostic
|
path = str(Path(path)) # os-agnostic
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
|
self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
|
||||||
|
@ -305,7 +305,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
# Preload labels (required for weighted CE training)
|
# Preload labels (required for weighted CE training)
|
||||||
self.imgs = [None] * n
|
self.imgs = [None] * n
|
||||||
self.labels = [None] * n
|
self.labels = [None] * n
|
||||||
if augment or image_weights: # cache labels for faster training
|
if cache_labels or image_weights: # cache labels for faster training
|
||||||
self.labels = [np.zeros((0, 5))] * n
|
self.labels = [np.zeros((0, 5))] * n
|
||||||
extract_bounding_boxes = False
|
extract_bounding_boxes = False
|
||||||
pbar = tqdm(self.label_files, desc='Reading labels')
|
pbar = tqdm(self.label_files, desc='Reading labels')
|
||||||
|
|
Loading…
Reference in New Issue