Author: Glenn Jocher
Date: 2019-06-12 11:50:24 +02:00
Parent: 9c328b1b0e
Commit: 5edb0ec40d
2 changed files with 6 additions and 7 deletions

View File

@@ -64,7 +64,7 @@ def train(
         epochs=100,  # 500200 batches at bs 4, 117263 images = 68 epochs
         batch_size=16,
         accumulate=4,  # effective bs = 64 = batch_size * accumulate
-        multi_scale=False,
+        multi_scale=True,
         freeze_backbone=False,
         transfer=False  # Transfer learning (train only YOLO layers)
 ):
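
Note on the accumulate=4 default: gradients are summed over several mini-batches before a single optimizer step, which is why the comment reads effective bs = 64 = batch_size * accumulate. A minimal standalone sketch of the idea; the names accumulate_step, loss_fn and batches are illustrative and not taken from train.py:

def accumulate_step(model, optimizer, loss_fn, batches, accumulate=4):
    # Gradient accumulation sketch: loss.backward() adds gradients in place,
    # so stepping once every `accumulate` mini-batches of 16 images behaves
    # like one batch of 16 * 4 = 64 images, at the cost of extra iterations.
    optimizer.zero_grad()
    for i, (imgs, targets) in enumerate(batches):
        loss = loss_fn(model(imgs), targets)
        loss.backward()
        if (i + 1) % accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()
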
@@ -73,12 +73,13 @@ def train(
     latest = weights + 'latest.pt'
     best = weights + 'best.pt'
     device = torch_utils.select_device()
-    torch.backends.cudnn.benchmark = True  # unsuitable for multiscale
     if multi_scale:
-        img_size = round((img_size / 32) * 1.5) * 32  # initiate with maximum multi_scale size
+        min_size = round(img_size / 32 / 1.5)
+        max_size = round(img_size / 32 * 1.5)
+        img_size = max_size * 32  # initiate with maximum multi_scale size
         # opt.num_workers = 0  # bug https://github.com/ultralytics/yolov3/issues/174
+    else:
+        torch.backends.cudnn.benchmark = True  # unsuitable for multiscale
     # Configure run
     data_dict = parse_data_cfg(data_cfg)
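
With multi_scale on, the grid-size bounds are now computed once up front, and cudnn.benchmark is only enabled in the fixed-size branch: benchmark mode re-tunes convolution kernels for each new input shape, so it hurts rather than helps when the image size changes every few batches. Plugging in the default img_size=416 (assumed here; the value comes from the train() signature elsewhere in the file), the bounds work out as follows:

img_size = 416                           # default input size (assumed for this example)
min_size = round(img_size / 32 / 1.5)    # round(8.67)  -> 9,  i.e. 9 * 32  = 288 px (~69%)
max_size = round(img_size / 32 * 1.5)    # round(19.5)  -> 20, i.e. 20 * 32 = 640 px (~154%)
img_size = max_size * 32                 # training starts at the largest scale, 640 px

The resulting 288-640 px window brackets the 67%-150% range mentioned in the training-loop comment; the edges overshoot slightly because the bounds are rounded to whole multiples of the 32-pixel network stride.
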
@@ -244,10 +245,7 @@ def train(
             # Multi-Scale training (67% - 150%) every 10 batches
             if multi_scale and (i + 1) % 10 == 0:
-                min_size = round(img_size / 32 / 1.5)
-                max_size = round(img_size / 32 * 1.5)
                 dataset.img_size = random.choice(range(min_size, max_size + 1)) * 32
                 dataloader = DataLoader(dataset,
                                         batch_size=batch_size,
                                         num_workers=opt.num_workers,
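
Since min_size and max_size now come from the setup code above instead of being recomputed here, the loop body only has to draw a new stride-aligned size every 10 batches. A standalone sketch of the sizes it can pick from, again assuming img_size=416:

import random

min_size, max_size = 9, 20                               # bounds computed once at setup (img_size=416)
sizes = [s * 32 for s in range(min_size, max_size + 1)]  # twelve multiples of the 32 px stride
print(sizes)                                             # 288, 320, 352, ... 608, 640
new_size = random.choice(sizes)                          # what dataset.img_size is set to every 10 batches
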

View File

@@ -153,6 +153,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                                 replace('.bmp', '.txt').
                                 replace('.png', '.txt') for x in self.img_files]
+        multi_scale = False
         if multi_scale:
             s = img_size / 32
             self.multi_scale = ((np.linspace(0.5, 1.5, nb) * s).round().astype(np.int) * 32)
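
The hard-coded multi_scale = False bypasses the dataset-level schedule just below it, so image-size selection is handled only by the training loop in train.py. For reference, the bypassed schedule produces a fixed ~50%-150% ramp over the batch indices rather than a random size every 10 batches. A small standalone example, with nb=5 batches chosen for readability and astype(int) in place of the deprecated np.int alias:

import numpy as np

img_size, nb = 416, 5
s = img_size / 32
sizes = (np.linspace(0.5, 1.5, nb) * s).round().astype(int) * 32
print(sizes)    # [192 320 416 512 640]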