updates
This commit is contained in:
parent
efe86b0c4c
commit
78e9bf60d2
1
train.py
1
train.py
|
@ -185,6 +185,7 @@ def train():
|
|||
hyp=hyp, # augmentation hyperparameters
|
||||
rect=opt.rect, # rectangular training
|
||||
image_weights=opt.img_weights,
|
||||
cache_labels=True if epochs > 10 else False,
|
||||
cache_images=False if opt.prebias else opt.cache_images)
|
||||
|
||||
# Dataloader
|
||||
|
|
|
@ -246,7 +246,7 @@ class LoadStreams: # multiple IP or RTSP cameras
|
|||
|
||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False,
|
||||
cache_images=False):
|
||||
cache_labels=False, cache_images=False):
|
||||
path = str(Path(path)) # os-agnostic
|
||||
with open(path, 'r') as f:
|
||||
self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
|
||||
|
@ -305,7 +305,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
# Preload labels (required for weighted CE training)
|
||||
self.imgs = [None] * n
|
||||
self.labels = [None] * n
|
||||
if augment or image_weights: # cache labels for faster training
|
||||
if cache_labels or image_weights: # cache labels for faster training
|
||||
self.labels = [np.zeros((0, 5))] * n
|
||||
extract_bounding_boxes = False
|
||||
pbar = tqdm(self.label_files, desc='Reading labels')
|
||||
|
|
Loading…
Reference in New Issue