diff --git a/train.py b/train.py index aa813d83..b13e8e75 100644 --- a/train.py +++ b/train.py @@ -193,7 +193,8 @@ def train(cfg, ('Epoch', 'gpu_mem', 'GIoU/xy', 'wh', 'obj', 'cls', 'total', 'targets', 'img_size')) # Update scheduler - scheduler.step() + if epoch > 0: + scheduler.step() # Freeze backbone at epoch 0, unfreeze at epoch 1 (optional) freeze_backbone = False