diff --git a/train.py b/train.py index a657d378..3f30b539 100644 --- a/train.py +++ b/train.py @@ -202,10 +202,6 @@ def train(): model.train() print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size')) - # Update scheduler - if epoch > 0: - scheduler.step() - # Freeze backbone at epoch 0, unfreeze at epoch 1 (optional) freeze_backbone = False if freeze_backbone and epoch < 2: @@ -286,6 +282,9 @@ def train(): # end batch ------------------------------------------------------------------------------------------------ + # Update scheduler + scheduler.step() + # Process epoch results final_epoch = epoch + 1 == epochs if opt.prebias: