align loss to darknet
This commit is contained in:
parent
a75119b8f0
commit
396a71001e
2
test.py
2
test.py
|
@ -7,7 +7,7 @@ parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
|
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
|
||||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
|
||||||
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
|
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
|
||||||
parser.add_argument('-weights_path', type=str, default='checkpoints/yolov3.pt', help='path to weights file')
|
parser.add_argument('-weights_path', type=str, default='checkpoints/latest.pt', help='path to weights file')
|
||||||
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
|
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
|
||||||
parser.add_argument('-iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
|
parser.add_argument('-iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
|
||||||
parser.add_argument('-conf_thres', type=float, default=0.5, help='object confidence threshold')
|
parser.add_argument('-conf_thres', type=float, default=0.5, help='object confidence threshold')
|
||||||
|
|
14
train.py
14
train.py
|
@ -6,7 +6,7 @@ from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-epochs', type=int, default=1, help='number of epochs')
|
parser.add_argument('-epochs', type=int, default=160, help='number of epochs')
|
||||||
parser.add_argument('-batch_size', type=int, default=12, help='size of each image batch')
|
parser.add_argument('-batch_size', type=int, default=12, help='size of each image batch')
|
||||||
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='data config file path')
|
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='data config file path')
|
||||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
||||||
|
@ -69,9 +69,9 @@ def main(opt):
|
||||||
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3,
|
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3,
|
||||||
momentum=.9, weight_decay=5e-4, nesterov=True)
|
momentum=.9, weight_decay=5e-4, nesterov=True)
|
||||||
|
|
||||||
|
start_epoch = checkpoint['epoch'] + 1
|
||||||
if checkpoint['optimizer'] is not None:
|
if checkpoint['optimizer'] is not None:
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
start_epoch = checkpoint['epoch'] + 1
|
|
||||||
best_loss = checkpoint['best_loss']
|
best_loss = checkpoint['best_loss']
|
||||||
|
|
||||||
del checkpoint # current, saved
|
del checkpoint # current, saved
|
||||||
|
@ -115,12 +115,10 @@ def main(opt):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# SGD burn-in
|
# SGD burn-in
|
||||||
# if (epoch == 0) & (i <= 1000):
|
if (epoch == 0) & (i <= 1000):
|
||||||
# power = 4
|
lr = 1e-3 * (i / 1000) ** 4
|
||||||
# lr = 1e-3 * (i / 1000) ** power
|
for g in optimizer.param_groups:
|
||||||
# for g in optimizer.param_groups:
|
g['lr'] = lr
|
||||||
# g['lr'] = lr
|
|
||||||
# # print('SGD Burn-In LR = %9.5g' % lr, end='')
|
|
||||||
|
|
||||||
# Compute loss, compute gradient, update parameters
|
# Compute loss, compute gradient, update parameters
|
||||||
loss = model(imgs.to(device), targets, requestPrecision=True)
|
loss = model(imgs.to(device), targets, requestPrecision=True)
|
||||||
|
|
Loading…
Reference in New Issue