From 313a3f6b0ca9821c23b99063da4c7d47b6414e2e Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 24 Sep 2018 03:06:04 +0200 Subject: [PATCH] updates --- test.py | 3 +-- train.py | 12 +++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test.py b/test.py index 64415126..b88faf0a 100644 --- a/test.py +++ b/test.py @@ -14,11 +14,10 @@ parser.add_argument('-conf_thres', type=float, default=0.5, help='object confide parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression') parser.add_argument('-n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation') parser.add_argument('-img_size', type=int, default=416, help='size of each image dimension') -parser.add_argument('-use_cuda', type=bool, default=True, help='whether to use cuda if available') opt = parser.parse_args() print(opt) -cuda = torch.cuda.is_available() and opt.use_cuda +cuda = torch.cuda.is_available() device = torch.device('cuda:0' if cuda else 'cpu') # Configure run diff --git a/train.py b/train.py index 7ab90867..2d959f52 100644 --- a/train.py +++ b/train.py @@ -44,7 +44,7 @@ def main(opt): # Get dataloader dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=opt.img_size, augment=True) - # reload saved optimizer state + # Reload saved optimizer state start_epoch = 0 best_loss = float('inf') if opt.resume: @@ -66,11 +66,13 @@ def main(opt): # Set optimizer # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters())) - optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters())) - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, + momentum=.9, weight_decay=5e-4, nesterov=True) - start_epoch = checkpoint['epoch'] + 1 - best_loss = checkpoint['best_loss'] + if checkpoint['optimizer'] is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + start_epoch = checkpoint['epoch'] + 1 + best_loss = checkpoint['best_loss'] del checkpoint # current, saved else: