diff --git a/train.py b/train.py
index 0f939769..2b93431e 100644
--- a/train.py
+++ b/train.py
@@ -138,7 +138,7 @@ def train(
     # plt.savefig('LR.png', dpi=300)
 
     # Dataset
-    dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, rect=False, image_weights=False)
+    dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, rect=False, cache=True)
 
     # Initialize distributed training
     if torch.cuda.device_count() > 1:
diff --git a/utils/datasets.py b/utils/datasets.py
index 5ebbc4de..c4395ccd 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -130,7 +130,7 @@ class LoadWebcam:  # for inference
 
 
 class LoadImagesAndLabels(Dataset):  # for training/testing
-    def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False):
+    def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False):
         with open(path, 'r') as f:
             img_files = f.read().splitlines()
             self.img_files = list(filter(lambda x: len(x) > 0, img_files))
@@ -185,7 +185,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.batch = bi  # batch index of image
 
         # Preload images
-        if n < 1001:  # preload all images into memory if possible
+        if cache and n < 1001:  # preload all images into memory if possible
             self.imgs = [cv2.imread(self.img_files[i]) for i in tqdm(range(n), desc='Reading images')]
 
         # Preload labels (required for weighted CE training)
diff --git a/utils/gcp.sh b/utils/gcp.sh
index ddb6eeed..7bd6e123 100755
--- a/utils/gcp.sh
+++ b/utils/gcp.sh
@@ -71,6 +71,7 @@ gsutil cp -r gs://sm4/supermarket2 .  # dataset from bucket
 rm -rf darknet && git clone https://github.com/AlexeyAB/darknet && cd darknet && wget -c https://pjreddie.com/media/files/darknet53.conv.74  # sudo apt install libopencv-dev && make
 ./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg darknet53.conv.74 -map -dont_show  # train
 ./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg backup/yolov3-spp-sm2-1cls_last.weights  # resume
+python3 train.py --data ../supermarket2/supermarket2.data --cfg cfg/yolov3-spp-sm2-1cls.cfg  # train
 python3 test.py --data ../supermarket2/supermarket2.data --weights ../darknet/backup/yolov3-spp-sm2-1cls_3000.weights  # test
 gsutil cp -r backup/*.weights gs://sm4/weights  # weights to bucket
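
Note: the new cache flag defaults to False, so existing callers of LoadImagesAndLabels keep reading images from disk per batch; train.py opts in with cache=True, which preloads every image via cv2.imread when the dataset holds fewer than 1001 images. A minimal usage sketch of the changed constructor (the list-file path 'data/train.txt' is an assumption for illustration, not from the patch):

from utils.datasets import LoadImagesAndLabels

# Default behavior (cache=False): images are read from disk on each access.
dataset = LoadImagesAndLabels('data/train.txt', img_size=416, batch_size=16)  # path is hypothetical

# Opt-in caching: if the dataset has fewer than 1001 images, all are read
# once with cv2.imread and kept in memory (self.imgs) to speed up later epochs.
dataset = LoadImagesAndLabels('data/train.txt', img_size=416, batch_size=16,
                              augment=True, rect=False, cache=True)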