ONNX compatibility updates
This commit is contained in:
parent
8ad8a64a0d
commit
eec0dc7b6c
17
train.py
17
train.py
|
@ -31,7 +31,9 @@ def train(
|
||||||
device = torch_utils.select_device()
|
device = torch_utils.select_device()
|
||||||
print("Using device: \"{}\"".format(device))
|
print("Using device: \"{}\"".format(device))
|
||||||
|
|
||||||
if not multi_scale:
|
if multi_scale: # pass maximum multi_scale size
|
||||||
|
img_size = 608
|
||||||
|
else:
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
os.makedirs(weights_path, exist_ok=True)
|
os.makedirs(weights_path, exist_ok=True)
|
||||||
|
@ -47,9 +49,6 @@ def train(
|
||||||
model = Darknet(net_config_path, img_size)
|
model = Darknet(net_config_path, img_size)
|
||||||
|
|
||||||
# Get dataloader
|
# Get dataloader
|
||||||
if multi_scale: # pass maximum multi_scale size
|
|
||||||
img_size = 608
|
|
||||||
|
|
||||||
dataloader = load_images_and_labels(train_path, batch_size=batch_size, img_size=img_size,
|
dataloader = load_images_and_labels(train_path, batch_size=batch_size, img_size=img_size,
|
||||||
multi_scale=multi_scale, augment=True)
|
multi_scale=multi_scale, augment=True)
|
||||||
|
|
||||||
|
@ -105,7 +104,7 @@ def train(
|
||||||
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
||||||
|
|
||||||
model_info(model)
|
model_info(model)
|
||||||
t0, t1 = time.time(), time.time()
|
t0 = time.time()
|
||||||
mean_recall, mean_precision = 0, 0
|
mean_recall, mean_precision = 0, 0
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
epoch += start_epoch
|
epoch += start_epoch
|
||||||
|
@ -183,8 +182,8 @@ def train(
|
||||||
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
|
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
|
||||||
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
|
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
|
||||||
rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'],
|
rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'],
|
||||||
model.losses['FP'], model.losses['FN'], time.time() - t1)
|
model.losses['FP'], model.losses['FN'], time.time() - t0)
|
||||||
t1 = time.time()
|
t0 = time.time()
|
||||||
print(s)
|
print(s)
|
||||||
|
|
||||||
# Update best loss
|
# Update best loss
|
||||||
|
@ -228,10 +227,6 @@ def train(
|
||||||
with open('results.txt', 'a') as file:
|
with open('results.txt', 'a') as file:
|
||||||
file.write(s + '%11.3g' * 3 % (mAP, P, R) + '\n')
|
file.write(s + '%11.3g' * 3 % (mAP, P, R) + '\n')
|
||||||
|
|
||||||
# Save final model
|
|
||||||
dt = time.time() - t0
|
|
||||||
print('Finished %g epochs in %.2fs (%.2fs/epoch)' % (epoch, dt, dt / (epoch + 1)))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
Loading…
Reference in New Issue