This commit is contained in:
Glenn Jocher 2018-08-26 15:40:07 +02:00
parent b965b6e9b7
commit 3fb6cc8161
3 changed files with 13 additions and 13 deletions

5
.gitignore vendored
View File

@ -11,11 +11,10 @@
*.weights *.weights
*.pt *.pt
*.weights *.weights
*results.txt
!zidane_result.jpg !zidane_result.jpg
#!coco_training_loss.png !coco_training_loss.png
results.txt
temp-plot.html temp-plot.html
# MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- # MATLAB GitIgnore -----------------------------------------------------------------------------------------------------

View File

@ -178,7 +178,7 @@ def main(opt):
os.system('cp checkpoints/latest.pt checkpoints/best.pt') os.system('cp checkpoints/latest.pt checkpoints/best.pt')
# Save backup checkpoint # Save backup checkpoint
if (epoch > 0) & (epoch % 100 == 0): if (epoch > 0) & (epoch % 10 == 0):
os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt') os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt')
# Save final model # Save final model

View File

@ -61,14 +61,15 @@ class ImageFolder(): # for eval-only
class ListDataset(): # for training class ListDataset(): # for training
def __init__(self, path, batch_size=1, img_size=608): def __init__(self, path, batch_size=1, img_size=608):
self.path = path self.path = path
#self.img_files = sorted(glob.glob('%s/*.*' % path)) # self.img_files = sorted(glob.glob('%s/*.*' % path))
with open(path, 'r') as file: with open(path, 'r') as file:
self.img_files = file.readlines() self.img_files = file.readlines()
if platform == 'darwin': # macos if platform == 'darwin': # macos
self.img_files = [path.replace('\n', '').replace('/images','/Users/glennjocher/Downloads/DATA/coco/images') for path in self.img_files] self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images')
for path in self.img_files]
else: else:
self.img_files = [path.replace('\n', '').replace('/images','../coco/images') for path in self.img_files] self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in
self.img_files] self.img_files]
@ -77,7 +78,7 @@ class ListDataset(): # for training
self.nB = math.ceil(self.nF / batch_size) # number of batches self.nB = math.ceil(self.nF / batch_size) # number of batches
self.batch_size = batch_size self.batch_size = batch_size
#assert self.nB > 0, 'No images found in path %s' % path # assert self.nB > 0, 'No images found in path %s' % path
self.height = img_size self.height = img_size
# RGB normalization values # RGB normalization values
@ -86,8 +87,8 @@ class ListDataset(): # for training
def __iter__(self): def __iter__(self):
self.count = -1 self.count = -1
# self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
self.shuffled_vector = np.arange(self.nF) # self.shuffled_vector = np.arange(self.nF) # not shuffled
return self return self
def __next__(self): def __next__(self):
@ -110,7 +111,7 @@ class ListDataset(): # for training
if img is None: if img is None:
continue continue
augment_hsv = True augment_hsv = False
if augment_hsv: if augment_hsv:
# SV augmentation by 50% # SV augmentation by 50%
fraction = 0.50 fraction = 0.50
@ -149,7 +150,7 @@ class ListDataset(): # for training
labels = np.array([]) labels = np.array([])
# Augment image and labels # Augment image and labels
img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB # img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
plotFlag = False plotFlag = False
if plotFlag: if plotFlag:
@ -163,7 +164,7 @@ class ListDataset(): # for training
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height
# random left-right flip # random left-right flip
lr_flip = True lr_flip = False
if lr_flip & (random.random() > 0.5): if lr_flip & (random.random() > 0.5):
img = np.fliplr(img) img = np.fliplr(img)
if nL > 0: if nL > 0: