train.py remove hardcoded weights/ path for weights.
If I want to store my weights in 'weights2' path: python train.py --weights-path weights2 Default is the same: weights
This commit is contained in:
parent
9c0c1f23ab
commit
868a116750
42
train.py
42
train.py
|
@ -11,6 +11,11 @@ from utils import torch_utils
|
||||||
# Import test.py to get mAP after each epoch
|
# Import test.py to get mAP after each epoch
|
||||||
import test
|
import test
|
||||||
|
|
||||||
|
DARKNET_WEIGHTS_FILENAME = 'darknet53.conv.74'
|
||||||
|
DARKNET_WEIGHTS_URL = 'https://pjreddie.com/media/files/{}'.format(
|
||||||
|
DARKNET_WEIGHTS_FILENAME
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
net_config_path,
|
net_config_path,
|
||||||
|
@ -19,6 +24,7 @@ def train(
|
||||||
resume=False,
|
resume=False,
|
||||||
epochs=100,
|
epochs=100,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
|
weights_path='weights',
|
||||||
report=False,
|
report=False,
|
||||||
multi_scale=False,
|
multi_scale=False,
|
||||||
freeze_backbone=True,
|
freeze_backbone=True,
|
||||||
|
@ -31,12 +37,14 @@ def train(
|
||||||
if not multi_scale:
|
if not multi_scale:
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
os.makedirs('weights', exist_ok=True)
|
os.makedirs(weights_path, exist_ok=True)
|
||||||
|
latest_weights_file = os.path.join(weights_path, 'latest.pt')
|
||||||
|
best_weights_file = os.path.join(weights_path, 'best.pt')
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
data_config = parse_data_config(data_config_path)
|
data_config = parse_data_config(data_config_path)
|
||||||
num_classes = int(data_config['classes'])
|
num_classes = int(data_config['classes'])
|
||||||
train_path = '../coco/trainvalno5k.txt'
|
train_path = data_config['train']
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = Darknet(net_config_path, img_size)
|
model = Darknet(net_config_path, img_size)
|
||||||
|
@ -50,7 +58,7 @@ def train(
|
||||||
|
|
||||||
lr0 = 0.001
|
lr0 = 0.001
|
||||||
if resume:
|
if resume:
|
||||||
checkpoint = torch.load('weights/latest.pt', map_location='cpu')
|
checkpoint = torch.load(latest_weights_file, map_location='cpu')
|
||||||
|
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
|
@ -79,9 +87,13 @@ def train(
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
|
|
||||||
# Initialize model with darknet53 weights (optional)
|
# Initialize model with darknet53 weights (optional)
|
||||||
if not os.path.isfile('weights/darknet53.conv.74'):
|
def_weight_file = os.path.join(weights_path, DARKNET_WEIGHTS_FILENAME)
|
||||||
os.system('wget https://pjreddie.com/media/files/darknet53.conv.74 -P weights')
|
if not os.path.isfile(def_weight_file):
|
||||||
load_weights(model, 'weights/darknet53.conv.74')
|
os.system('wget {} -P {}'.format(
|
||||||
|
DARKNET_WEIGHTS_URL,
|
||||||
|
weights_path))
|
||||||
|
assert os.path.isfile(def_weight_file)
|
||||||
|
load_weights(model, def_weight_file)
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21')
|
raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21')
|
||||||
|
@ -187,21 +199,29 @@ def train(
|
||||||
'best_loss': best_loss,
|
'best_loss': best_loss,
|
||||||
'model': model.state_dict(),
|
'model': model.state_dict(),
|
||||||
'optimizer': optimizer.state_dict()}
|
'optimizer': optimizer.state_dict()}
|
||||||
torch.save(checkpoint, 'weights/latest.pt')
|
torch.save(checkpoint, latest_weights_file)
|
||||||
|
|
||||||
# Save best checkpoint
|
# Save best checkpoint
|
||||||
if best_loss == loss_per_target:
|
if best_loss == loss_per_target:
|
||||||
os.system('cp weights/latest.pt weights/best.pt')
|
os.system('cp {} {}'.format(
|
||||||
|
latest_weights_file,
|
||||||
|
best_weights_file,
|
||||||
|
))
|
||||||
|
|
||||||
# Save backup weights every 5 epochs
|
# Save backup weights every 5 epochs
|
||||||
if (epoch > 0) & (epoch % 5 == 0):
|
if (epoch > 0) & (epoch % 5 == 0):
|
||||||
os.system('cp weights/latest.pt weights/backup' + str(epoch) + '.pt')
|
backup_file_name = 'backup{}.pt'.format(epoch)
|
||||||
|
backup_file_path = os.path.join(weights_path, backup_file_name)
|
||||||
|
os.system('cp {} {}'.format(
|
||||||
|
latest_weights_file,
|
||||||
|
backup_file_path,
|
||||||
|
))
|
||||||
|
|
||||||
# Calculate mAP
|
# Calculate mAP
|
||||||
mAP, R, P = test.test(
|
mAP, R, P = test.test(
|
||||||
net_config_path,
|
net_config_path,
|
||||||
data_config_path,
|
data_config_path,
|
||||||
'weights/latest.pt',
|
latest_weights_file,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
img_size=img_size,
|
img_size=img_size,
|
||||||
)
|
)
|
||||||
|
@ -224,6 +244,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
||||||
parser.add_argument('--multi-scale', default=False, help='random image sizes per batch 320 - 608')
|
parser.add_argument('--multi-scale', default=False, help='random image sizes per batch 320 - 608')
|
||||||
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
|
parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
|
||||||
|
parser.add_argument('--weights-path', type=str, default='weights', help='path to store weights')
|
||||||
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
||||||
parser.add_argument('--report', action='store_true', help='report TP, FP, FN, P and R per batch (slower)')
|
parser.add_argument('--report', action='store_true', help='report TP, FP, FN, P and R per batch (slower)')
|
||||||
parser.add_argument('--freeze-darknet53', default=False, help='freeze darknet53.conv.74 layers for first epoch')
|
parser.add_argument('--freeze-darknet53', default=False, help='freeze darknet53.conv.74 layers for first epoch')
|
||||||
|
@ -241,6 +262,7 @@ if __name__ == '__main__':
|
||||||
resume=opt.resume,
|
resume=opt.resume,
|
||||||
epochs=opt.epochs,
|
epochs=opt.epochs,
|
||||||
batch_size=opt.batch_size,
|
batch_size=opt.batch_size,
|
||||||
|
weights_path=opt.weights_path,
|
||||||
report=opt.report,
|
report=opt.report,
|
||||||
multi_scale=opt.multi_scale,
|
multi_scale=opt.multi_scale,
|
||||||
freeze_backbone=opt.freeze_darknet53,
|
freeze_backbone=opt.freeze_darknet53,
|
||||||
|
|
Loading…
Reference in New Issue