
To fix the chainability at epoch zero for some schedulers#63457

Closed
iramazanli wants to merge 1 commit into pytorch:master from iramazanli:group_lr_base_lr

Conversation

@iramazanli
Contributor

@iramazanli iramazanli commented Aug 18, 2021

It has been discussed in #60836 (comment) that we have observed an obstacle to chaining some types of learning rate schedulers. In particular, we observed that:

  • some learning rate schedulers return their initial learning rates at epoch 0, as in
       return self.base_lrs
  • this can be a problem when two schedulers are chained as
     scheduler1.step()
     scheduler2.step()

In particular, we completely ignore the effect of scheduler1 at epoch 0. This would not be an issue if scheduler1 were ineffective at epoch 0, as many schedulers are; however, for schedulers such as warm-up schedulers, whose multiplicative factor at epoch 0 is smaller than 1, this can lead to undesired behavior.

The following code snippet illustrates the problem:

Reproducing the bug

import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR, ExponentialLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 1.0)
scheduler1 = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")
scheduler2 = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(10):
    print(epoch, scheduler2.get_last_lr()[0])
    optimizer.step()
    scheduler1.step()
    scheduler2.step()

Current Result

0 1.0
1 0.9
2 0.81
3 0.7290000000000001
4 0.6561000000000001
5 5.904900000000001
6 5.314410000000001
7 4.782969000000001
8 4.304672100000001
9 3.874204890000001

Expected Result

0 1.0
1 0.9
2 0.81
3 0.7290000000000001
4 0.6561000000000001
5 0.5904900000000001
6 0.5314410000000001
7 0.4782969000000001
8 0.4304672100000001
9 0.3874204890000001
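The contract behind the fix can be shown with a small, torch-free sketch. The names below (broken_step, chainable_step) and the dict-based param group are illustrative only, not the actual PyTorch implementation: a chainable scheduler must derive the next learning rate from the optimizer's current lr, not from base_lrs, otherwise a previously-stepped scheduler's epoch-0 factor is silently discarded.

```python
# Torch-free sketch of the chainability contract this PR fixes.
# broken_step / chainable_step are hypothetical names for illustration.

def broken_step(group, base_lr, gamma, epoch):
    """Buggy pattern: at epoch 0, reset to base_lr, wiping any factor
    a previously-stepped scheduler already applied to group["lr"]."""
    if epoch == 0:
        group["lr"] = base_lr
    else:
        group["lr"] *= gamma

def chainable_step(group, gamma):
    """Fixed pattern: always scale the optimizer's *current* lr."""
    group["lr"] *= gamma

# A warm-up scheduler steps first and scales the lr by 0.1.
group = {"lr": 1.0}
group["lr"] *= 0.1
broken_step(group, base_lr=1.0, gamma=0.9, epoch=0)
lr_broken = group["lr"]          # back to 1.0 -- warm-up factor lost

group = {"lr": 1.0}
group["lr"] *= 0.1
chainable_step(group, gamma=0.9)
lr_chainable = group["lr"]       # ~0.09 -- warm-up factor preserved

print(lr_broken, lr_chainable)
```

This is why the buggy run above jumps to 5.9049 at epoch 5: the warm-up's 0.1 factor was never actually in effect, so removing it at the end of warm-up multiplies the already-decayed lr by 10.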

Partially resolves pytorch/vision#4281

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Aug 18, 2021


💊 CI failures summary and remediations

As of commit ce80523 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@codecov
Copy link

codecov bot commented Aug 18, 2021

Codecov Report

Merging #63457 (ce80523) into master (4a390a5) will decrease coverage by 0.04%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #63457      +/-   ##
==========================================
- Coverage   75.56%   75.51%   -0.05%     
==========================================
  Files        2118     2118              
  Lines      212263   212291      +28     
==========================================
- Hits       160399   160316      -83     
- Misses      51864    51975     +111     

@iramazanli iramazanli requested review from datumbox and fmassa August 19, 2021 04:48
Contributor

@datumbox datumbox left a comment


LGTM, thanks @iramazanli for fixing this so quickly.

I tested your patch on latest nightly with a slightly modified loop:

for epoch in range(10):
    print(epoch, scheduler2.get_lr())
    optimizer.step()
    scheduler1.step()
    scheduler2.step()

And I get the expected result:

0 [0.1]
1 [0.08100000000000002]
2 [0.07290000000000002]
3 [0.06561000000000002]
4 [0.05904900000000002]
5 [0.5314410000000002]
6 [0.47829690000000014]
7 [0.43046721000000016]
8 [0.38742048900000015]
9 [0.34867844010000015]

Which is the combined effect of both schedulers.
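As a quick arithmetic sanity check of that table (my own reading of the numbers, not part of the PR): during warm-up (epochs 1–4) each entry carries the constant warm-up factor 0.1 on top of the exponential decay, and after warm-up (epochs 5–9) the entries are a pure exponential decay.

```python
# Sanity-check the table above (assumed reading, not from the PR itself):
# epochs 1-4: lr = 0.1 * 0.9 ** (epoch + 1)   (warm-up factor still applied)
# epochs 5-9: lr = 0.9 ** (epoch + 1)         (warm-up factor cleanly removed)
warmup_factor, gamma = 0.1, 0.9

during_warmup = [warmup_factor * gamma ** (e + 1) for e in range(1, 5)]
after_warmup = [gamma ** (e + 1) for e in range(5, 10)]

print(during_warmup)  # matches epochs 1-4 of the table
print(after_warmup)   # matches epochs 5-9 of the table
```

The clean factor-of-10 hand-off between epoch 4 and epoch 5 is exactly the combined-effect behavior the patch restores.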

@facebook-github-bot
Contributor

@iramazanli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Member

@fmassa fmassa left a comment


Great, thanks for fixing this!

@iramazanli
Contributor Author

> LGTM, thanks @iramazanli for fixing this so quickly. [...] I tested your patch on latest nightly with a slightly modified loop [...] And I get the expected result [...] Which is the combined effect of both schedulers.

That's amazing! Let's merge this PR then :)

@facebook-github-bot
Contributor

@iramazanli merged this pull request in e7c4988.


Projects

None yet

Development

Successfully merging this pull request may close these issues.

Update reference scripts to use the "Batteries Included" utils

4 participants