This commit is contained in:
Glenn Jocher 2018-08-26 15:40:07 +02:00
parent b965b6e9b7
commit 3fb6cc8161
3 changed files with 13 additions and 13 deletions

5
.gitignore vendored
View File

@ -11,11 +11,10 @@
*.weights
*.pt
*.weights
*results.txt
!zidane_result.jpg
#!coco_training_loss.png
!coco_training_loss.png
results.txt
temp-plot.html
# MATLAB GitIgnore -----------------------------------------------------------------------------------------------------

View File

@ -178,7 +178,7 @@ def main(opt):
os.system('cp checkpoints/latest.pt checkpoints/best.pt')
# Save backup checkpoint
if (epoch > 0) & (epoch % 100 == 0):
if (epoch > 0) & (epoch % 10 == 0):
os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt')
# Save final model

View File

@ -61,14 +61,15 @@ class ImageFolder(): # for eval-only
class ListDataset(): # for training
def __init__(self, path, batch_size=1, img_size=608):
self.path = path
#self.img_files = sorted(glob.glob('%s/*.*' % path))
# self.img_files = sorted(glob.glob('%s/*.*' % path))
with open(path, 'r') as file:
self.img_files = file.readlines()
if platform == 'darwin': # macos
self.img_files = [path.replace('\n', '').replace('/images','/Users/glennjocher/Downloads/DATA/coco/images') for path in self.img_files]
self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images')
for path in self.img_files]
else:
self.img_files = [path.replace('\n', '').replace('/images','../coco/images') for path in self.img_files]
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in
self.img_files]
@ -77,7 +78,7 @@ class ListDataset(): # for training
self.nB = math.ceil(self.nF / batch_size) # number of batches
self.batch_size = batch_size
#assert self.nB > 0, 'No images found in path %s' % path
# assert self.nB > 0, 'No images found in path %s' % path
self.height = img_size
# RGB normalization values
@ -86,8 +87,8 @@ class ListDataset(): # for training
def __iter__(self):
self.count = -1
# self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
self.shuffled_vector = np.arange(self.nF)
self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
# self.shuffled_vector = np.arange(self.nF) # not shuffled
return self
def __next__(self):
@ -110,7 +111,7 @@ class ListDataset(): # for training
if img is None:
continue
augment_hsv = True
augment_hsv = False
if augment_hsv:
# SV augmentation by 50%
fraction = 0.50
@ -149,7 +150,7 @@ class ListDataset(): # for training
labels = np.array([])
# Augment image and labels
img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
# img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
plotFlag = False
if plotFlag:
@ -163,7 +164,7 @@ class ListDataset(): # for training
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height
# random left-right flip
lr_flip = True
lr_flip = False
if lr_flip & (random.random() > 0.5):
img = np.fliplr(img)
if nL > 0: