ONNX compatibility updates

This commit is contained in:
Glenn Jocher 2018-12-28 20:09:06 +01:00
parent 8ad8a64a0d
commit eec0dc7b6c
1 changed files with 6 additions and 11 deletions

View File

@ -31,7 +31,9 @@ def train(
device = torch_utils.select_device() device = torch_utils.select_device()
print("Using device: \"{}\"".format(device)) print("Using device: \"{}\"".format(device))
if not multi_scale: if multi_scale: # pass maximum multi_scale size
img_size = 608
else:
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
os.makedirs(weights_path, exist_ok=True) os.makedirs(weights_path, exist_ok=True)
@ -47,9 +49,6 @@ def train(
model = Darknet(net_config_path, img_size) model = Darknet(net_config_path, img_size)
# Get dataloader # Get dataloader
if multi_scale: # pass maximum multi_scale size
img_size = 608
dataloader = load_images_and_labels(train_path, batch_size=batch_size, img_size=img_size, dataloader = load_images_and_labels(train_path, batch_size=batch_size, img_size=img_size,
multi_scale=multi_scale, augment=True) multi_scale=multi_scale, augment=True)
@ -105,7 +104,7 @@ def train(
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1) # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
model_info(model) model_info(model)
t0, t1 = time.time(), time.time() t0 = time.time()
mean_recall, mean_precision = 0, 0 mean_recall, mean_precision = 0, 0
for epoch in range(epochs): for epoch in range(epochs):
epoch += start_epoch epoch += start_epoch
@ -183,8 +182,8 @@ def train(
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'], '%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'], rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'], rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'],
model.losses['FP'], model.losses['FN'], time.time() - t1) model.losses['FP'], model.losses['FN'], time.time() - t0)
t1 = time.time() t0 = time.time()
print(s) print(s)
# Update best loss # Update best loss
@ -228,10 +227,6 @@ def train(
with open('results.txt', 'a') as file: with open('results.txt', 'a') as file:
file.write(s + '%11.3g' * 3 % (mAP, P, R) + '\n') file.write(s + '%11.3g' * 3 % (mAP, P, R) + '\n')
# Save final model
dt = time.time() - t0
print('Finished %g epochs in %.2fs (%.2fs/epoch)' % (epoch, dt, dt / (epoch + 1)))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()