SequentialLR scheduler incorrect initialization #72874

@antoniojkim

Description

🐛 Describe the bug

When multiple learning rate schedulers share an optimizer, the order in which they are initialized is not taken into account. This is a problem when they are initialized in sequential order (as one might intuitively do).

Each scheduler calls step() on initialization, which sets the lr in its optimizer's param_groups. As a result, epoch 0 uses the lr set by the very last scheduler to be constructed (in the case of initializing schedulers sequentially) instead of the lr set by the first one.
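The step-on-init behaviour can be seen in isolation, even without SequentialLR. A minimal sketch (the printed value assumes ConstantLR's default total_iters, so the factor applies at epoch 0):

```python
import torch

opt = torch.optim.SGD([torch.tensor(0.5)], lr=0.1)

# Merely constructing the scheduler already writes to opt.param_groups:
torch.optim.lr_scheduler.ConstantLR(opt, factor=0.5)

print(opt.param_groups[0]["lr"])  # 0.05: the constructor stepped once
```

This is why constructing a second scheduler on the same optimizer overwrites whatever lr the first one set.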

So, for example, we get this incorrect behaviour here:

import torch

optimizer = torch.optim.SGD([torch.tensor(0.5)], lr=0.1)
print(optimizer.param_groups[0]["lr"])  # 0.1, as expected

schedulers = [
    torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1),
    torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1)
]
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=[10])

print(optimizer.param_groups[0]["lr"])  # 0.01, which is incorrect. It should be 0.1

Evidently, the optimizer's learning rate was set by the second scheduler during its initialization, which is incorrect behaviour.
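Until a fix lands, one possible workaround (a sketch, assuming the first scheduler's get_last_lr() still holds the lr it computed at its own initialization, before later schedulers overwrote the optimizer) is to re-apply the first scheduler's lr after constructing SequentialLR:

```python
import torch

optimizer = torch.optim.SGD([torch.tensor(0.5)], lr=0.1)

schedulers = [
    torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1),
    torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1),
]
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=[10])

# Workaround: restore each param group's lr to the value the first
# scheduler computed, so epoch 0 sees the intended starting lr.
for group, lr in zip(optimizer.param_groups, schedulers[0].get_last_lr()):
    group["lr"] = lr

print(optimizer.param_groups[0]["lr"])  # 0.1
```

This only patches the initial state; the proper fix is for SequentialLR itself to undo the extra initialization steps.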

A possible fix for this issue is proposed in #72856

Versions

Collecting environment information...
PyTorch version: 1.10.0a0+git4cf0146
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: CentOS Linux release 7.9.2009 (Core) (x86_64)
GCC version: Could not collect
Clang version: 3.4.2 (tags/RELEASE_34/dot2-final)
CMake version: version 3.15.5
Libc version: glibc-2.2.5

Python version: 3.7.4 (default, Jan 28 2022, 18:52:51) [GCC 8.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.49.1.el7.x86_64-x86_64-with-centos-7.9.2009-Core
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.5
[pip3] torch==1.10.0a0+gite63eb02
[pip3] torch-xla==1.10
[pip3] torchmetrics==0.6.2
[pip3] torchvision==0.11.0a0+e7ec7e2
[conda] Could not collect

cc @vincentqb @jbschlosser @albanD
