updates
This commit is contained in:
parent
b965b6e9b7
commit
3fb6cc8161
|
@ -11,11 +11,10 @@
|
||||||
*.weights
|
*.weights
|
||||||
*.pt
|
*.pt
|
||||||
*.weights
|
*.weights
|
||||||
*results.txt
|
|
||||||
!zidane_result.jpg
|
!zidane_result.jpg
|
||||||
#!coco_training_loss.png
|
!coco_training_loss.png
|
||||||
|
|
||||||
|
|
||||||
|
results.txt
|
||||||
temp-plot.html
|
temp-plot.html
|
||||||
|
|
||||||
# MATLAB GitIgnore -----------------------------------------------------------------------------------------------------
|
# MATLAB GitIgnore -----------------------------------------------------------------------------------------------------
|
||||||
|
|
2
train.py
2
train.py
|
@ -178,7 +178,7 @@ def main(opt):
|
||||||
os.system('cp checkpoints/latest.pt checkpoints/best.pt')
|
os.system('cp checkpoints/latest.pt checkpoints/best.pt')
|
||||||
|
|
||||||
# Save backup checkpoint
|
# Save backup checkpoint
|
||||||
if (epoch > 0) & (epoch % 100 == 0):
|
if (epoch > 0) & (epoch % 10 == 0):
|
||||||
os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt')
|
os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt')
|
||||||
|
|
||||||
# Save final model
|
# Save final model
|
||||||
|
|
|
@ -61,14 +61,15 @@ class ImageFolder(): # for eval-only
|
||||||
class ListDataset(): # for training
|
class ListDataset(): # for training
|
||||||
def __init__(self, path, batch_size=1, img_size=608):
|
def __init__(self, path, batch_size=1, img_size=608):
|
||||||
self.path = path
|
self.path = path
|
||||||
#self.img_files = sorted(glob.glob('%s/*.*' % path))
|
# self.img_files = sorted(glob.glob('%s/*.*' % path))
|
||||||
with open(path, 'r') as file:
|
with open(path, 'r') as file:
|
||||||
self.img_files = file.readlines()
|
self.img_files = file.readlines()
|
||||||
|
|
||||||
if platform == 'darwin': # macos
|
if platform == 'darwin': # macos
|
||||||
self.img_files = [path.replace('\n', '').replace('/images','/Users/glennjocher/Downloads/DATA/coco/images') for path in self.img_files]
|
self.img_files = [path.replace('\n', '').replace('/images', '/Users/glennjocher/Downloads/DATA/coco/images')
|
||||||
|
for path in self.img_files]
|
||||||
else:
|
else:
|
||||||
self.img_files = [path.replace('\n', '').replace('/images','../coco/images') for path in self.img_files]
|
self.img_files = [path.replace('\n', '').replace('/images', '../coco/images') for path in self.img_files]
|
||||||
|
|
||||||
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in
|
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in
|
||||||
self.img_files]
|
self.img_files]
|
||||||
|
@ -77,7 +78,7 @@ class ListDataset(): # for training
|
||||||
self.nB = math.ceil(self.nF / batch_size) # number of batches
|
self.nB = math.ceil(self.nF / batch_size) # number of batches
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
#assert self.nB > 0, 'No images found in path %s' % path
|
# assert self.nB > 0, 'No images found in path %s' % path
|
||||||
self.height = img_size
|
self.height = img_size
|
||||||
|
|
||||||
# RGB normalization values
|
# RGB normalization values
|
||||||
|
@ -86,8 +87,8 @@ class ListDataset(): # for training
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
self.count = -1
|
self.count = -1
|
||||||
# self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
|
self.shuffled_vector = np.random.permutation(self.nF) # shuffled vector
|
||||||
self.shuffled_vector = np.arange(self.nF)
|
# self.shuffled_vector = np.arange(self.nF) # not shuffled
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
|
@ -110,7 +111,7 @@ class ListDataset(): # for training
|
||||||
if img is None:
|
if img is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
augment_hsv = True
|
augment_hsv = False
|
||||||
if augment_hsv:
|
if augment_hsv:
|
||||||
# SV augmentation by 50%
|
# SV augmentation by 50%
|
||||||
fraction = 0.50
|
fraction = 0.50
|
||||||
|
@ -149,7 +150,7 @@ class ListDataset(): # for training
|
||||||
labels = np.array([])
|
labels = np.array([])
|
||||||
|
|
||||||
# Augment image and labels
|
# Augment image and labels
|
||||||
img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
|
# img, labels, M = random_affine(img, targets=labels, degrees=(-10, 10), translate=(0.2, 0.2), scale=(0.8, 1.2)) # RGB
|
||||||
|
|
||||||
plotFlag = False
|
plotFlag = False
|
||||||
if plotFlag:
|
if plotFlag:
|
||||||
|
@ -163,7 +164,7 @@ class ListDataset(): # for training
|
||||||
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height
|
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5].copy()) / height
|
||||||
|
|
||||||
# random left-right flip
|
# random left-right flip
|
||||||
lr_flip = True
|
lr_flip = False
|
||||||
if lr_flip & (random.random() > 0.5):
|
if lr_flip & (random.random() > 0.5):
|
||||||
img = np.fliplr(img)
|
img = np.fliplr(img)
|
||||||
if nL > 0:
|
if nL > 0:
|
||||||
|
|
Loading…
Reference in New Issue