This commit is contained in:
Glenn Jocher 2019-05-21 16:03:29 +02:00
parent d2589fc5f7
commit d09db54cb0
3 changed files with 4 additions and 3 deletions

View File

@ -138,7 +138,7 @@ def train(
# plt.savefig('LR.png', dpi=300)
# Dataset
dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, rect=False, image_weights=False)
dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, rect=False, cache=True)
# Initialize distributed training
if torch.cuda.device_count() > 1:

View File

@ -130,7 +130,7 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False):
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False):
with open(path, 'r') as f:
img_files = f.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
@ -185,7 +185,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.batch = bi # batch index of image
# Preload images
if n < 1001: # preload all images into memory if possible
if cache & n < 1001: # preload all images into memory if possible
self.imgs = [cv2.imread(self.img_files[i]) for i in tqdm(range(n), desc='Reading images')]
# Preload labels (required for weighted CE training)

View File

@ -71,6 +71,7 @@ gsutil cp -r gs://sm4/supermarket2 . # dataset from bucket
rm -rf darknet && git clone https://github.com/AlexeyAB/darknet && cd darknet && wget -c https://pjreddie.com/media/files/darknet53.conv.74 # sudo apt install libopencv-dev && make
./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg darknet53.conv.74 -map -dont_show # train
./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg backup/yolov3-spp-sm2-1cls_last.weights # resume
python3 train.py --data ../supermarket2/supermarket2.data --cfg cfg/yolov3-spp-sm2-1cls.cfg # test
python3 test.py --data ../supermarket2/supermarket2.data --weights ../darknet/backup/yolov3-spp-sm2-1cls_3000.weights # test
gsutil cp -r backup/*.weights gs://sm4/weights # weights to bucket