updates
This commit is contained in:
parent
d2589fc5f7
commit
d09db54cb0
2
train.py
2
train.py
|
@ -138,7 +138,7 @@ def train(
|
|||
# plt.savefig('LR.png', dpi=300)
|
||||
|
||||
# Dataset
|
||||
dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, rect=False, image_weights=False)
|
||||
dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, rect=False, cache=True)
|
||||
|
||||
# Initialize distributed training
|
||||
if torch.cuda.device_count() > 1:
|
||||
|
|
|
@ -130,7 +130,7 @@ class LoadWebcam: # for inference
|
|||
|
||||
|
||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False):
|
||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False):
|
||||
with open(path, 'r') as f:
|
||||
img_files = f.read().splitlines()
|
||||
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
|
||||
|
@ -185,7 +185,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
self.batch = bi # batch index of image
|
||||
|
||||
# Preload images
|
||||
if n < 1001: # preload all images into memory if possible
|
||||
if cache & n < 1001: # preload all images into memory if possible
|
||||
self.imgs = [cv2.imread(self.img_files[i]) for i in tqdm(range(n), desc='Reading images')]
|
||||
|
||||
# Preload labels (required for weighted CE training)
|
||||
|
|
|
@ -71,6 +71,7 @@ gsutil cp -r gs://sm4/supermarket2 . # dataset from bucket
|
|||
rm -rf darknet && git clone https://github.com/AlexeyAB/darknet && cd darknet && wget -c https://pjreddie.com/media/files/darknet53.conv.74 # sudo apt install libopencv-dev && make
|
||||
./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg darknet53.conv.74 -map -dont_show # train
|
||||
./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg backup/yolov3-spp-sm2-1cls_last.weights # resume
|
||||
python3 train.py --data ../supermarket2/supermarket2.data --cfg cfg/yolov3-spp-sm2-1cls.cfg # test
|
||||
python3 test.py --data ../supermarket2/supermarket2.data --weights ../darknet/backup/yolov3-spp-sm2-1cls_3000.weights # test
|
||||
gsutil cp -r backup/*.weights gs://sm4/weights # weights to bucket
|
||||
|
||||
|
|
Loading…
Reference in New Issue