From eec0dc7b6c12fe89bdd6907269e0e05986e8e9b6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 28 Dec 2018 20:09:06 +0100 Subject: [PATCH] ONNX compatibility updates --- train.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index 0ad9c479..9232139d 100644 --- a/train.py +++ b/train.py @@ -31,7 +31,9 @@ def train( device = torch_utils.select_device() print("Using device: \"{}\"".format(device)) - if not multi_scale: + if multi_scale: # pass maximum multi_scale size + img_size = 608 + else: torch.backends.cudnn.benchmark = True os.makedirs(weights_path, exist_ok=True) @@ -47,9 +49,6 @@ def train( model = Darknet(net_config_path, img_size) # Get dataloader - if multi_scale: # pass maximum multi_scale size - img_size = 608 - dataloader = load_images_and_labels(train_path, batch_size=batch_size, img_size=img_size, multi_scale=multi_scale, augment=True) @@ -105,7 +104,7 @@ def train( # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1) model_info(model) - t0, t1 = time.time(), time.time() + t0 = time.time() mean_recall, mean_precision = 0, 0 for epoch in range(epochs): epoch += start_epoch @@ -183,8 +182,8 @@ def train( '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'], rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'], rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'], - model.losses['FP'], model.losses['FN'], time.time() - t1) - t1 = time.time() + model.losses['FP'], model.losses['FN'], time.time() - t0) + t0 = time.time() print(s) # Update best loss @@ -228,10 +227,6 @@ def train( with open('results.txt', 'a') as file: file.write(s + '%11.3g' * 3 % (mAP, P, R) + '\n') - # Save final model - dt = time.time() - t0 - print('Finished %g epochs in %.2fs (%.2fs/epoch)' % (epoch, dt, dt / (epoch + 1))) - if __name__ == '__main__': parser = argparse.ArgumentParser()