ONNX compatibility updates
This commit is contained in:
parent
8ad8a64a0d
commit
eec0dc7b6c
17
train.py
17
train.py
|
@ -31,7 +31,9 @@ def train(
|
|||
device = torch_utils.select_device()
|
||||
print("Using device: \"{}\"".format(device))
|
||||
|
||||
if not multi_scale:
|
||||
if multi_scale: # pass maximum multi_scale size
|
||||
img_size = 608
|
||||
else:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
os.makedirs(weights_path, exist_ok=True)
|
||||
|
@ -47,9 +49,6 @@ def train(
|
|||
model = Darknet(net_config_path, img_size)
|
||||
|
||||
# Get dataloader
|
||||
if multi_scale: # pass maximum multi_scale size
|
||||
img_size = 608
|
||||
|
||||
dataloader = load_images_and_labels(train_path, batch_size=batch_size, img_size=img_size,
|
||||
multi_scale=multi_scale, augment=True)
|
||||
|
||||
|
@ -105,7 +104,7 @@ def train(
|
|||
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
||||
|
||||
model_info(model)
|
||||
t0, t1 = time.time(), time.time()
|
||||
t0 = time.time()
|
||||
mean_recall, mean_precision = 0, 0
|
||||
for epoch in range(epochs):
|
||||
epoch += start_epoch
|
||||
|
@ -183,8 +182,8 @@ def train(
|
|||
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (i, len(dataloader) - 1), rloss['x'],
|
||||
rloss['y'], rloss['w'], rloss['h'], rloss['conf'], rloss['cls'],
|
||||
rloss['loss'], mean_precision, mean_recall, model.losses['nT'], model.losses['TP'],
|
||||
model.losses['FP'], model.losses['FN'], time.time() - t1)
|
||||
t1 = time.time()
|
||||
model.losses['FP'], model.losses['FN'], time.time() - t0)
|
||||
t0 = time.time()
|
||||
print(s)
|
||||
|
||||
# Update best loss
|
||||
|
@ -228,10 +227,6 @@ def train(
|
|||
with open('results.txt', 'a') as file:
|
||||
file.write(s + '%11.3g' * 3 % (mAP, P, R) + '\n')
|
||||
|
||||
# Save final model
|
||||
dt = time.time() - t0
|
||||
print('Finished %g epochs in %.2fs (%.2fs/epoch)' % (epoch, dt, dt / (epoch + 1)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
Loading…
Reference in New Issue