updates
This commit is contained in:
parent
b965b6e9b7
commit
3fb6cc8161
|
@ -11,11 +11,10 @@
|
|||
*.weights
|
||||
*.pt
|
||||
*.weights
|
||||
*results.txt
|
||||
!zidane_result.jpg
|
||||
#!coco_training_loss.png
|
||||
|
||||
!coco_training_loss.png
|
||||
|
||||
results.txt
|
||||
temp-plot.html
|
||||
|
||||
# MATLAB GitIgnore -----------------------------------------------------------------------------------------------------
|
||||
|
|
2
train.py
2
train.py
|
@ -178,7 +178,7 @@ def main(opt):
|
|||
os.system('cp checkpoints/latest.pt checkpoints/best.pt')
|
||||
|
||||
# Save backup checkpoint
|
||||
if (epoch > 0) & (epoch % 100 == 0):
|
||||
if (epoch > 0) & (epoch % 10 == 0):
|
||||
os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt')
|
||||
|
||||
# Save final model
|
||||
|
|
|
@ -66,7 +66,8 @@ class ListDataset(): # for training
|
|||
self.img_files = file.readlines()
|
||||
|
||||
if platform == 'darwin': # macos
|
||||
self.img_files = [path.replace('\n', '').replace('/images','/Users/glennjocher/Downloads/DATA/coco/images') for path in self.img_files]
|
||||
self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images')
|
||||
for path in self.img_files]
|
||||
else:
|
||||
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
|
||||
|
||||
|
@ -86,8 +87,8 @@ class ListDataset(): # for training
|
|||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
# self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
|
||||
self.shuffled_vector = np.arange(self.nF)
|
||||
self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
|
||||
# self.shuffled_vector = np.arange(self.nF) # not shuffled
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
|
@ -110,7 +111,7 @@ class ListDataset(): # for training
|
|||
if img is None:
|
||||
continue
|
||||
|
||||
augment_hsv = True
|
||||
augment_hsv = False
|
||||
if augment_hsv:
|
||||
# SV augmentation by 50%
|
||||
fraction = 0.50
|
||||
|
@ -149,7 +150,7 @@ class ListDataset(): # for training
|
|||
labels = np.array([])
|
||||
|
||||
# Augment image and labels
|
||||
img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
|
||||
# img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
|
||||
|
||||
plotFlag = False
|
||||
if plotFlag:
|
||||
|
@ -163,7 +164,7 @@ class ListDataset(): # for training
|
|||
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height
|
||||
|
||||
# random left-right flip
|
||||
lr_flip = True
|
||||
lr_flip = False
|
||||
if lr_flip & (random.random() > 0.5):
|
||||
img = np.fliplr(img)
|
||||
if nL > 0:
|
||||
|
|
Loading…
Reference in New Issue