This commit is contained in:
Glenn Jocher 2019-03-20 22:10:18 +02:00
parent 9885903baf
commit 327aaebd7c
1 changed files with 2 additions and 1 deletions

View File

@ -36,6 +36,7 @@ def train(
# Get dataloader # Get dataloader
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True) dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
# dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size, num_workers=0)
lr0 = 0.001 # initial learning rate lr0 = 0.001 # initial learning rate
cutoff = -1 # backbone reaches to cutoff layer cutoff = -1 # backbone reaches to cutoff layer
@ -81,7 +82,7 @@ def train(
# Start training # Start training
t0 = time.time() t0 = time.time()
model_info(model) model_info(model)
n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches n_burnin = min(round(len(dataloader) / 5 + 1), 1000) # burn-in batches
for epoch in range(epochs): for epoch in range(epochs):
model.train() model.train()
epoch += start_epoch epoch += start_epoch