This commit is contained in:
Glenn Jocher 2019-05-21 17:37:34 +02:00
parent 520e58aa05
commit 19c7697434
4 changed files with 31 additions and 12 deletions

View File

@ -148,7 +148,7 @@ def test(
# Print results
pf = '%20s' + '%10.3g' * 6 # print format
print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1), end='\n\n')
print(pf % ('all', seen, nt.sum(), mp, mr, map, mf1))
# Print results per class
if nc > 1 and len(stats):

View File

@ -75,7 +75,7 @@ def train(
device = torch_utils.select_device()
if multi_scale:
img_size = 608 # initiate with maximum multi_scale size
img_size = round((img_size / 32) * 1.5) * 32 # initiate with maximum multi_scale size
opt.num_workers = 0 # bug https://github.com/ultralytics/yolov3/issues/174
else:
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
@ -138,7 +138,13 @@ def train(
# plt.savefig('LR.png', dpi=300)
# Dataset
dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, rect=False, cache=True)
dataset = LoadImagesAndLabels(train_path,
img_size,
batch_size,
augment=True,
rect=False,
cache=True,
multi_scale=multi_scale)
# Initialize distributed training
if torch.cuda.device_count() > 1:

View File

@ -130,14 +130,19 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False):
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False,
multi_scale=False):
with open(path, 'r') as f:
img_files = f.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
n = len(self.img_files)
self.n = n
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches
assert n > 0, 'No images found in %s' % path
self.n = n
self.batch = bi # batch index of image
self.img_size = img_size
self.augment = augment
self.image_weights = image_weights
@ -148,11 +153,13 @@ class LoadImagesAndLabels(Dataset): # for training/testing
replace('.bmp', '.txt').
replace('.png', '.txt') for x in self.img_files]
if multi_scale:
s = img_size / 32
self.multi_scale = ((np.linspace(0.5, 1.5, nb) * s).round().astype(np.int) * 32)
# Rectangular Training https://github.com/ultralytics/yolov3/issues/232
if self.rect:
from PIL import Image
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches
# Read image shapes
sp = 'data' + os.sep + path.replace('.txt', '.shapes').split(os.sep)[-1] # shapefile path
@ -182,7 +189,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
shapes[i] = [1, 1 / mini]
self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32
self.batch = bi # batch index of image
# Preload images
if cache and (n < 1001): # preload all images into memory if possible
@ -207,6 +213,12 @@ class LoadImagesAndLabels(Dataset): # for training/testing
def __len__(self):
return len(self.img_files)
# def __iter__(self):
# self.count = -1
# print('ran dataset iter')
# #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
# return self
def __getitem__(self, index):
if self.image_weights:
index = self.indices[index]
@ -242,10 +254,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Letterbox
h, w, _ = img.shape
if self.rect:
new_shape = self.batch_shapes[self.batch[index]]
img, ratio, padw, padh = letterbox(img, new_shape=new_shape, mode='rect')
shape = self.batch_shapes[self.batch[index]]
img, ratio, padw, padh = letterbox(img, new_shape=shape, mode='rect')
else:
img, ratio, padw, padh = letterbox(img, new_shape=self.img_size, mode='square')
shape = int(self.multi_scale[self.batch[index]]) if hasattr(self, 'multi_scale') else self.img_size
img, ratio, padw, padh = letterbox(img, new_shape=shape, mode='square')
# Load labels
labels = []

View File

@ -72,7 +72,7 @@ rm -rf darknet && git clone https://github.com/AlexeyAB/darknet && cd darknet &&
./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg darknet53.conv.74 -map -dont_show # train
./darknet detector train ../supermarket2/supermarket2.data cfg/yolov3-spp-sm2-1cls.cfg backup/yolov3-spp-sm2-1cls_last.weights # resume
python3 train.py --data ../supermarket2/supermarket2.data --cfg cfg/yolov3-spp-sm2-1cls.cfg # test
python3 test.py --data ../supermarket2/supermarket2.data --weights ../darknet/backup/yolov3-spp-sm2-1cls_3000.weights # test
python3 test.py --data ../supermarket2/supermarket2.data --weights ../darknet/backup/yolov3-spp-sm2-1cls_5000.weights --cfg cfg/yolov3-spp-sm2-1cls.cfg # test
gsutil cp -r backup/*.weights gs://sm4/weights # weights to bucket
# Debug/Development