diff --git a/train.py b/train.py index e12a0937..3fd6972f 100644 --- a/train.py +++ b/train.py @@ -87,7 +87,7 @@ def train( t0 = time.time() model_info(model) - n_burnin = min(dataloader.nB / 5, 1000) # number of burn-in batches + n_burnin = min(round(dataloader.nB / 5), 1000) # number of burn-in batches for epoch in range(epochs): epoch += start_epoch