updates
This commit is contained in:
parent
68b9df4dd4
commit
001193b9c7
6
train.py
6
train.py
|
@ -55,7 +55,6 @@ hyp = {'xy': 0.2, # xy loss gain
|
||||||
# 'momentum': 0.9025, # SGD momentum
|
# 'momentum': 0.9025, # SGD momentum
|
||||||
# 'weight_decay': 0.0005417} # optimizer weight decay
|
# 'weight_decay': 0.0005417} # optimizer weight decay
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
cfg,
|
cfg,
|
||||||
data_cfg,
|
data_cfg,
|
||||||
|
@ -143,7 +142,6 @@ def train(
|
||||||
batch_size,
|
batch_size,
|
||||||
augment=True,
|
augment=True,
|
||||||
rect=False,
|
rect=False,
|
||||||
cache=True,
|
|
||||||
multi_scale=multi_scale)
|
multi_scale=multi_scale)
|
||||||
|
|
||||||
# Initialize distributed training
|
# Initialize distributed training
|
||||||
|
@ -168,7 +166,7 @@ def train(
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||||
|
|
||||||
# Remove old results
|
# Remove old results
|
||||||
for f in glob.glob('train_batch*.jpg') + glob.glob('test_batch*.jpg') + ['results.txt']:
|
for f in glob.glob('*_batch*.jpg') + glob.glob('results.txt'):
|
||||||
os.remove(f)
|
os.remove(f)
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
|
@ -307,7 +305,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--batch-size', type=int, default=16, help='size of each image batch')
|
parser.add_argument('--batch-size', type=int, default=16, help='size of each image batch')
|
||||||
parser.add_argument('--accumulate', type=int, default=1, help='accumulate gradient x batches before optimizing')
|
parser.add_argument('--accumulate', type=int, default=1, help='accumulate gradient x batches before optimizing')
|
||||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
|
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
|
||||||
parser.add_argument('--data-cfg', type=str, default='data/coco_32img.data', help='coco.data file path')
|
parser.add_argument('--data-cfg', type=str, default='data/coco.data', help='coco.data file path')
|
||||||
parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608')
|
parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608')
|
||||||
parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
|
parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
|
||||||
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
||||||
|
|
|
@ -130,7 +130,7 @@ class LoadWebcam: # for inference
|
||||||
|
|
||||||
|
|
||||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False, cache=False,
|
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True, image_weights=False,
|
||||||
multi_scale=False):
|
multi_scale=False):
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
img_files = f.read().splitlines()
|
img_files = f.read().splitlines()
|
||||||
|
@ -190,11 +190,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
|
|
||||||
self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32
|
self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32
|
||||||
|
|
||||||
# Preload images
|
|
||||||
if cache and (n < 1001): # preload all images into memory if possible
|
|
||||||
self.imgs = [cv2.imread(self.img_files[i]) for i in tqdm(range(n), desc='Reading images')]
|
|
||||||
|
|
||||||
# Preload labels (required for weighted CE training)
|
# Preload labels (required for weighted CE training)
|
||||||
|
self.imgs = [None] * n
|
||||||
self.labels = [np.zeros((0, 5))] * n
|
self.labels = [np.zeros((0, 5))] * n
|
||||||
iter = tqdm(self.label_files, desc='Reading labels') if n > 1000 else self.label_files
|
iter = tqdm(self.label_files, desc='Reading labels') if n > 1000 else self.label_files
|
||||||
for i, file in enumerate(iter):
|
for i, file in enumerate(iter):
|
||||||
|
@ -227,10 +224,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
label_path = self.label_files[index]
|
label_path = self.label_files[index]
|
||||||
|
|
||||||
# Load image
|
# Load image
|
||||||
if hasattr(self, 'imgs'): # preloaded
|
|
||||||
img = self.imgs[index]
|
img = self.imgs[index]
|
||||||
else:
|
if img is None:
|
||||||
img = cv2.imread(img_path) # BGR
|
img = cv2.imread(img_path) # BGR
|
||||||
|
if self.n < 1001:
|
||||||
|
self.imgs[index] = img # cache image into memory
|
||||||
assert img is not None, 'File Not Found ' + img_path
|
assert img is not None, 'File Not Found ' + img_path
|
||||||
|
|
||||||
# Augment colorspace
|
# Augment colorspace
|
||||||
|
|
Loading…
Reference in New Issue