diff --git a/train.py b/train.py index 3fd6972f..a0dd3222 100644 --- a/train.py +++ b/train.py @@ -87,7 +87,7 @@ def train( t0 = time.time() model_info(model) - n_burnin = min(round(dataloader.nB / 5), 1000) # number of burn-in batches + n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches for epoch in range(epochs): epoch += start_epoch