updates
This commit is contained in:
parent
9885903baf
commit
327aaebd7c
3
train.py
3
train.py
|
@ -36,6 +36,7 @@ def train(
|
||||||
|
|
||||||
# Get dataloader
|
# Get dataloader
|
||||||
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
|
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
|
||||||
|
# dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size, num_workers=0)
|
||||||
|
|
||||||
lr0 = 0.001 # initial learning rate
|
lr0 = 0.001 # initial learning rate
|
||||||
cutoff = -1 # backbone reaches to cutoff layer
|
cutoff = -1 # backbone reaches to cutoff layer
|
||||||
|
@ -81,7 +82,7 @@ def train(
|
||||||
# Start training
|
# Start training
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
model_info(model)
|
model_info(model)
|
||||||
n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches
|
n_burnin = min(round(len(dataloader) / 5 + 1), 1000) # burn-in batches
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
model.train()
|
model.train()
|
||||||
epoch += start_epoch
|
epoch += start_epoch
|
||||||
|
|
Loading…
Reference in New Issue