This commit is contained in:
Glenn Jocher 2018-08-26 15:40:07 +02:00
parent b965b6e9b7
commit 3fb6cc8161
3 changed files with 13 additions and 13 deletions

5
.gitignore vendored
View File

@ -11,11 +11,10 @@
*.weights *.weights
*.pt *.pt
*.weights *.weights
*results.txt
!zidane_result.jpg !zidane_result.jpg
#!coco_training_loss.png !coco_training_loss.png
results.txt
temp-plot.html temp-plot.html
# MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- # MATLAB GitIgnore -----------------------------------------------------------------------------------------------------

View File

@ -178,7 +178,7 @@ def main(opt):
os.system('cp checkpoints/latest.pt checkpoints/best.pt') os.system('cp checkpoints/latest.pt checkpoints/best.pt')
# Save backup checkpoint # Save backup checkpoint
if (epoch > 0) & (epoch % 100 == 0): if (epoch > 0) & (epoch % 10 == 0):
os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt') os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt')
# Save final model # Save final model

View File

@ -66,7 +66,8 @@ class ListDataset(): # for training
self.img_files = file.readlines() self.img_files = file.readlines()
if platform == 'darwin': # macos if platform == 'darwin': # macos
self.img_files = [path.replace('\n', '').replace('/images','/Users/glennjocher/Downloads/DATA/coco/images') for path in self.img_files] self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images')
for path in self.img_files]
else: else:
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files] self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
@ -86,8 +87,8 @@ class ListDataset(): # for training
def __iter__(self): def __iter__(self):
self.count = -1 self.count = -1
# self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
self.shuffled_vector = np.arange(self.nF) # self.shuffled_vector = np.arange(self.nF) # not shuffled
return self return self
def __next__(self): def __next__(self):
@ -110,7 +111,7 @@ class ListDataset(): # for training
if img is None: if img is None:
continue continue
augment_hsv = True augment_hsv = False
if augment_hsv: if augment_hsv:
# SV augmentation by 50% # SV augmentation by 50%
fraction = 0.50 fraction = 0.50
@ -149,7 +150,7 @@ class ListDataset(): # for training
labels = np.array([]) labels = np.array([])
# Augment image and labels # Augment image and labels
img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB # img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
plotFlag = False plotFlag = False
if plotFlag: if plotFlag:
@ -163,7 +164,7 @@ class ListDataset(): # for training
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height
# random left-right flip # random left-right flip
lr_flip = True lr_flip = False
if lr_flip & (random.random() > 0.5): if lr_flip & (random.random() > 0.5):
img = np.fliplr(img) img = np.fliplr(img)
if nL > 0: if nL > 0: