This commit is contained in:
Glenn Jocher 2019-09-11 14:25:48 +02:00
parent 270724e507
commit 806d7b92d8
1 changed files with 6 additions and 4 deletions

View File

@ -1,5 +1,4 @@
import argparse import argparse
import time
import torch.distributed as dist import torch.distributed as dist
import torch.optim as optim import torch.optim as optim
@ -88,8 +87,10 @@ def train():
else: else:
pg0 += [v] # parameter group 0 pg0 += [v] # parameter group 0
# optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam:
optimizer = optim.Adam(pg0, lr=hyp['lr0'])
# optimizer = AdaBound(pg0, lr=hyp['lr0'], final_lr=0.1) # optimizer = AdaBound(pg0, lr=hyp['lr0'], final_lr=0.1)
else:
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True) optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
del pg0, pg1 del pg0, pg1
@ -388,6 +389,7 @@ if __name__ == '__main__':
parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training') parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training')
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1') parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1')
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
parser.add_argument('--var', type=float, help='debug variable') parser.add_argument('--var', type=float, help='debug variable')
opt = parser.parse_args() opt = parser.parse_args()
opt.weights = 'weights/last.pt' if opt.resume else opt.weights opt.weights = 'weights/last.pt' if opt.resume else opt.weights