This commit is contained in:
Glenn Jocher 2018-09-24 03:06:04 +02:00
parent 5d402ad31a
commit 313a3f6b0c
2 changed files with 8 additions and 7 deletions

View File

@ -14,11 +14,10 @@ parser.add_argument('-conf_thres', type=float, default=0.5, help='object confide
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression') parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
parser.add_argument('-n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation') parser.add_argument('-n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
parser.add_argument('-img_size', type=int, default=416, help='size of each image dimension') parser.add_argument('-img_size', type=int, default=416, help='size of each image dimension')
parser.add_argument('-use_cuda', type=bool, default=True, help='whether to use cuda if available')
opt = parser.parse_args() opt = parser.parse_args()
print(opt) print(opt)
cuda = torch.cuda.is_available() and opt.use_cuda cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu') device = torch.device('cuda:0' if cuda else 'cpu')
# Configure run # Configure run

View File

@ -44,7 +44,7 @@ def main(opt):
# Get dataloader # Get dataloader
dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=opt.img_size, augment=True) dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=opt.img_size, augment=True)
# reload saved optimizer state # Reload saved optimizer state
start_epoch = 0 start_epoch = 0
best_loss = float('inf') best_loss = float('inf')
if opt.resume: if opt.resume:
@ -66,9 +66,11 @@ def main(opt):
# Set optimizer # Set optimizer
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters())) # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters())) optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3,
optimizer.load_state_dict(checkpoint['optimizer']) momentum=.9, weight_decay=5e-4, nesterov=True)
if checkpoint['optimizer'] is not None:
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1 start_epoch = checkpoint['epoch'] + 1
best_loss = checkpoint['best_loss'] best_loss = checkpoint['best_loss']