
Fix SequentialLR initialization #72856

Closed

antoniojkim wants to merge 12 commits into pytorch:master from antoniojkim:antoniojkim/fix_sequential_lr_init

Conversation

@antoniojkim
Collaborator

What was happening is that when we have multiple learning rate schedulers, the order in which they are initialized is not taken into account. This is a problem if they are initialized in sequential order (as one might intuitively do).

Each scheduler calls step() on initialization and sets the lr in its optimizer's param_groups. However, this means that step 0 will use the lr that was set by the very last scheduler initialized (in the case of initializing schedulers sequentially) instead of the first scheduler.

The fix in this PR addresses the above bug by performing a call to the appropriate scheduler on initialization, after decrementing the last_epoch values in order to keep them the same post-step. This ensures that the correct scheduler is the one setting the lr values for the optimizer's param_groups.
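A minimal sketch of the behaviour being described (hypothetical model and scheduler choices, not the exact reproduction from the issue):

import torch

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Constructed sequentially, as one might intuitively do; each constructor
# internally calls step() and writes an lr into optimizer.param_groups.
warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1, total_iters=2)
decay = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, decay], milestones=[2])

# The intended lr for step 0 is the ConstantLR warm-up value, 0.1 * 0.1 = 0.01.
# Before this fix, the value printed here instead reflects whichever scheduler
# happened to step last during construction.
print(optimizer.param_groups[0]["lr"])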

@pytorch-bot

pytorch-bot Bot commented Feb 15, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/antoniojkim/pytorch/blob/50a4f970b36acf76b7f4015875336d1866a78f88/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflow, Labels, Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk, ciflow/xla ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped

@facebook-github-bot
Contributor

facebook-github-bot commented Feb 15, 2022

✅ No Failures (0 Pending)

As of commit e505c16 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@H-Huang H-Huang added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 15, 2022
@albanD
Collaborator

albanD commented Feb 15, 2022

Is there an issue associated with this?

@antoniojkim
Collaborator Author

Is there an issue associated with this?

I did not open an issue for this. It's just something I noticed and decided to contribute a fix for. Should I open an issue just for reference?

@antoniojkim
Collaborator Author

antoniojkim commented Feb 15, 2022

okay, so this test is failing

pytorch/test/test_optim.py

Lines 1441 to 1454 in 5dd0732

def test_get_last_lr_sequentiallr(self):
    epochs = 12
    milestones = [3, 6]
    schedulers = [None] * 3
    schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
    schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
    schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
    scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
    constant_lr_target = [0.005] * 3
    exponential_lr_target = [0.05, 0.04, 0.032]
    step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
    single_targets = constant_lr_target + exponential_lr_target + step_lr_target
    targets = [single_targets, [x * 10 for x in single_targets]]
    self._test_get_last_lr(scheduler, targets, epochs)

but I'm not entirely convinced that this test is valid. Specifically, this line

targets = [single_targets, [x * 10 for x in single_targets]]

@jaketae what is the purpose of multiplying the single_targets by 10?

It was added in #70558. @albanD do you have any ideas?

@albanD
Collaborator

albanD commented Feb 15, 2022

We usually discuss these things on issues before working on an implementation.

@antoniojkim
Collaborator Author

We usually discuss these things on issues before working on an implementation.

Issue created here: #72874

@antoniojkim
Collaborator Author

antoniojkim commented Feb 15, 2022

Going back to the SequentialLR tests that appear to be failing, I'm not convinced they are correct to begin with.

Another test that is failing is this one:

pytorch/test/test_optim.py

Lines 1429 to 1439 in 84729ce

def test_sequentiallr3(self):
    epochs = 12
    schedulers = [None] * 3
    targets = [[0.005, 0.005, 0.005] + [0.05, 0.04, 0.032]
               + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]]
    milestones = [3, 6]
    schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
    schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
    schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
    scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
    self._test(scheduler, targets, epochs)

The optimizer is set as

pytorch/test/test_optim.py

Lines 921 to 923 in 84729ce

self.opt = SGD(
    [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
    lr=0.05)

The expected target for epoch 0 is set as 0.005 when it should be 0.05, seeing as it's a ConstantLR with factor 0.1.

Is there something I'm missing here?

@antoniojkim
Collaborator Author

@albanD Just pinging to get some help on this

Collaborator

@albanD albanD left a comment


The fix sounds generally ok.
I'm wondering if we can refactor this to reduce duplicated code (the bisect_right in particular).
Also this will need testing!

Comment thread collect_env.py Outdated
Comment thread torch/optim/lr_scheduler.py Outdated
Collaborator


Is this still ok in the case where 0 is passed to the step() function below?

Collaborator Author


It should be. If 0 is passed in, then last_epoch is set in step() anyways. So, this decrement shouldn't have any effect in that case.
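A small check along those lines (illustrative, using ExponentialLR as a stand-in for any _LRScheduler):

import torch

opt = torch.optim.SGD([torch.tensor(0.5)], lr=0.1)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)

sched.last_epoch -= 1    # the decrement introduced by this PR
sched.step(0)            # an explicit epoch overwrites last_epoch inside step()
print(sched.last_epoch)  # 0, so the decrement has no lasting effect in this case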

@antoniojkim antoniojkim force-pushed the antoniojkim/fix_sequential_lr_init branch from 841aa88 to aa3b75b on February 22, 2022 15:43
@antoniojkim
Collaborator Author

@albanD Thanks for the initial review! What are your thoughts on the comments I made above about the current tests not being correct?

@albanD
Collaborator

albanD commented Feb 22, 2022

Not sure about the test tbh.
cc @jbschlosser who might have some time to dive into this?

@antoniojkim
Collaborator Author

antoniojkim commented Mar 3, 2022

@albanD Just checking up on this PR again. What's the right course of action here? Should I modify those tests to reflect what I think should be correct values?

@jbschlosser
Contributor

Sorry for the delay - I can look into this. Will post findings shortly.

Contributor

@jbschlosser jbschlosser left a comment


@antoniojkim I don't see this fix addressing the issue - your reproduction code still gives the same invalid results after the fix.

This may have been mentioned already, but the core of the issue is that the _LRSchedulers (specifically ConstantLR in this case) mutate the optimizer during initialization. This happens regardless of SequentialLR usage.

import torch

optimizer = torch.optim.SGD([torch.tensor(0.5)], lr=0.1)
print(optimizer.param_groups[0]["lr"])  # 0.1
torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1)
print(optimizer.param_groups[0]["lr"])  # 0.01
torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1)
print(optimizer.param_groups[0]["lr"])  # 0.001

Not sure what a proper fix is yet tbh and unfortunately I don't have time to dig into this more atm. Happy to review any possible fixes you come up with, although I will be out next week so there will be some delay.

@antoniojkim antoniojkim force-pushed the antoniojkim/fix_sequential_lr_init branch from 08def64 to ed79f21 on March 8, 2022 21:49
@antoniojkim
Collaborator Author

@jbschlosser I just added a reset_optimizer_lr method that I think may fix it. Can you please take a look and let me know what you think?

Comment thread torch/optim/lr_scheduler.py Outdated
@antoniojkim antoniojkim force-pushed the antoniojkim/fix_sequential_lr_init branch from ed79f21 to 4b545c2 on March 15, 2022 14:02
@antoniojkim antoniojkim requested a review from jbschlosser March 15, 2022 14:02
@antoniojkim
Collaborator Author

@jbschlosser I fixed the reset_optimizer_lr method to use the initial_lr key in the param group instead. I think that should cover the case that you mentioned.
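For reference, a rough sketch of the kind of reset being described (the function shown here is illustrative, not the PR's actual diff):

def reset_optimizer_lr(optimizer):
    # Put every param group back to the lr it had before any scheduler's
    # __init__ overwrote it; _LRScheduler records that value under 'initial_lr'.
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"]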

@antoniojkim
Collaborator Author

@jbschlosser gentle reminder to please review this PR again

@antoniojkim antoniojkim force-pushed the antoniojkim/fix_sequential_lr_init branch from c6b1b7e to 565220e on June 15, 2022 21:10
@antoniojkim
Collaborator Author

Unless I'm mistaken, I don't think the failing tests are related to the changes I made. Can I get another review @jbschlosser?

@antoniojkim antoniojkim requested a review from jbschlosser June 16, 2022 19:31
Contributor

@jbschlosser jbschlosser left a comment


Nice work, looks correct now wrt tests! One minor API thing and I think we can finally get it merged

Comment thread torch/optim/lr_scheduler.py Outdated
self.step()

@contextmanager
def init_optimizer_lr(self):
Contributor


Need to name this privately (_init_optimizer_lr) and we're good to go :)

- This is primarily to prevent the changes from affecting
  ChainedScheduler
@antoniojkim
Collaborator Author

@jbschlosser There was a bug in _init_optimizer_lr that meant my changes weren't actually having an effect. That's why the tests were passing before. They seem to be failing again on mismatching lr values.

I added a test that verifies the example case I outlined in issue #72874, and it appears to be passing. So I believe the behaviour is correct now. Can you please check the values in the failing sequential_lr tests? I'm suspicious of those target values.

@jbschlosser
Contributor

Hey @antoniojkim, I just checked the values for the first failing test:

======================================================================
FAIL [0.003s]: test_get_last_lr_sequentiallr (__main__.TestLRScheduler)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_optim.py", line 1502, in test_get_last_lr_sequentiallr
    self._test_get_last_lr(scheduler, targets, epochs)
  File "test_optim.py", line 2257, in _test_get_last_lr
    epoch, t, r), atol=1e-5, rtol=0)
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 2273, in assertEqual
    msg=(lambda generated_msg: f"{generated_msg} : {msg}") if isinstance(msg, str) and self.longMessage else msg,
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Scalars are not close!
Absolute difference: 0.045000000000000005 (up to 1e-05 allowed)
Relative difference: 0.9 (up to 0 allowed)
The failure occurred for item [0][0] : LR is wrong in epoch 0: expected [0.005, 0.05], got [0.05, 0.5]

pytorch/test/test_optim.py

Lines 1473 to 1486 in 2ede287

def test_get_last_lr_sequentiallr(self):
    epochs = 12
    milestones = [3, 6]
    schedulers = [None] * 3
    schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
    schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
    schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
    scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
    constant_lr_target = [0.005] * 3
    exponential_lr_target = [0.05, 0.04, 0.032]
    step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
    single_targets = constant_lr_target + exponential_lr_target + step_lr_target
    targets = [single_targets, [x * 10 for x in single_targets]]
    self._test_get_last_lr(scheduler, targets, epochs)

pytorch/test/test_optim.py

Lines 941 to 943 in 2ede287

self.opt = SGD(
    [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
    lr=0.05)

Given a ConstantLR with factor=0.1 as the first scheduler, the target values of [0.005, 0.05] seem correct to me.
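The arithmetic behind that, assuming ConstantLR scales each group's lr by factor for the first total_iters epochs:

import math

base_lrs = [0.05, 0.5]   # the two param-group lrs from the SGD setup above
factor = 0.1             # ConstantLR factor, in effect for the first total_iters=3 epochs
epoch0 = [lr * factor for lr in base_lrs]
assert all(math.isclose(a, b) for a, b in zip(epoch0, [0.005, 0.05]))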

@antoniojkim
Collaborator Author

@jbschlosser I found the issue. It was related to the last_epoch variable being incremented. I added a decrement to adjust accordingly.

All tests green now!
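A rough sketch of the overall approach described in the PR summary (a hypothetical helper, not the exact diff that landed in SequentialLR.__init__):

def sequential_lr_initial_step(optimizer, schedulers):
    # 1) Restore the lrs that each wrapped scheduler's __init__ overwrote.
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"]
    # 2) Decrement last_epoch so the step below leaves it unchanged post-step.
    for scheduler in schedulers:
        scheduler.last_epoch -= 1
    # 3) Only the first (currently active) scheduler sets the step-0 lr.
    schedulers[0].step()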

Contributor

@jbschlosser jbschlosser left a comment


Nice work, LGTM :)

@jbschlosser
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot
Collaborator

@antoniojkim your PR has been successfully merged.

@github-actions
Contributor

Hey @antoniojkim.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@antoniojkim
Collaborator Author

Nice work, LGTM :)

Awesome. Thanks for all your help with this!

@jbschlosser jbschlosser added release notes: python_frontend python frontend release notes category topic: bug fixes topic category labels Jun 21, 2022
facebook-github-bot pushed a commit that referenced this pull request Jun 22, 2022
Summary:
What was happening is that when we have multiple learning rate schedulers, the order in which they are being initialized is not being taken into account. This is a problem if they were being initialized in sequential order (as one might intuitively do).

Each scheduler calls `step()` on initialization and sets the `lr` in its optimizer's `params_groups`. However, this means that step 0 will be using the `lr` that was set by the very last scheduler (in the case of initializing schedulers sequentially) instead of the first scheduler.

The fix in this PR, addresses the above bug by performing a call to the appropriate scheduler on initialization after decrementing the `last_epoch` values in order to keep them the same post-step. This will ensure that the correct scheduler is the one setting the `lr` values for the optimizer's `param_groups`

Pull Request resolved: #72856
Approved by: https://github.com/jbschlosser

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/765b6a8fab02a55cf85058037f9ab89fb00a5b23

Reviewed By: atalman

Differential Revision: D37327411

fbshipit-source-id: dc461f1de257fe13f54fede2b093644bc6e51a89
@antoniojkim
Collaborator Author

@jbschlosser I think @Queuecumber wanted this fix in the next PyTorch release if possible. Is there still time to do that?

@jbschlosser
Contributor

Hey @antoniojkim, unfortunately the release branch cut for 1.12 was a decent while back. The fix will make it into 1.13 though.

miladm pushed a commit to miladm/pytorch that referenced this pull request Jun 27, 2022
Pull Request resolved: pytorch#72856
Approved by: https://github.com/jbschlosser
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Pull Request resolved: pytorch#72856
Approved by: https://github.com/jbschlosser

Labels

cla signed, Merged, open source, release notes: python_frontend, Stale, topic: bug fixes, triaged

Projects

None yet


9 participants