updates
This commit is contained in:
parent
270724e507
commit
806d7b92d8
6
train.py
6
train.py
|
@ -1,5 +1,4 @@
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
@ -88,8 +87,10 @@ def train():
|
||||||
else:
|
else:
|
||||||
pg0 += [v] # parameter group 0
|
pg0 += [v] # parameter group 0
|
||||||
|
|
||||||
# optimizer = optim.Adam(pg0, lr=hyp['lr0'])
|
if opt.adam:
|
||||||
|
optimizer = optim.Adam(pg0, lr=hyp['lr0'])
|
||||||
# optimizer = AdaBound(pg0, lr=hyp['lr0'], final_lr=0.1)
|
# optimizer = AdaBound(pg0, lr=hyp['lr0'], final_lr=0.1)
|
||||||
|
else:
|
||||||
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
||||||
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
|
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
|
||||||
del pg0, pg1
|
del pg0, pg1
|
||||||
|
@ -388,6 +389,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training')
|
parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training')
|
||||||
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
|
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
|
||||||
parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1')
|
parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1')
|
||||||
|
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
|
||||||
parser.add_argument('--var', type=float, help='debug variable')
|
parser.add_argument('--var', type=float, help='debug variable')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
opt.weights = 'weights/last.pt' if opt.resume else opt.weights
|
opt.weights = 'weights/last.pt' if opt.resume else opt.weights
|
||||||
|
|
Loading…
Reference in New Issue