updates
This commit is contained in:
parent
9885903baf
commit
327aaebd7c
3
train.py
3
train.py
|
@ -36,6 +36,7 @@ def train(
|
|||
|
||||
# Get dataloader
|
||||
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
|
||||
# dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size, num_workers=0)
|
||||
|
||||
lr0 = 0.001 # initial learning rate
|
||||
cutoff = -1 # backbone reaches to cutoff layer
|
||||
|
@ -81,7 +82,7 @@ def train(
|
|||
# Start training
|
||||
t0 = time.time()
|
||||
model_info(model)
|
||||
n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches
|
||||
n_burnin = min(round(len(dataloader) / 5 + 1), 1000) # burn-in batches
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
epoch += start_epoch
|
||||
|
|
Loading…
Reference in New Issue