Skip to content

feat: add PolynomialLR scheduler#82769

Closed
federicopozzi33 wants to merge 6 commits intopytorch:masterfrom
federicopozzi33:feature/4438-poly-lr-scheduler
Closed

feat: add PolynomialLR scheduler#82769
federicopozzi33 wants to merge 6 commits intopytorch:masterfrom
federicopozzi33:feature/4438-poly-lr-scheduler

Conversation

@federicopozzi33
Copy link
Copy Markdown
Contributor

@federicopozzi33 federicopozzi33 commented Aug 3, 2022

Description

Add PolynomialLR scheduler.

Issue

Closes #79511.

Testing

I added tests for PolynomialLR.

@facebook-github-bot
Copy link
Copy Markdown
Contributor

facebook-github-bot commented Aug 3, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 20f6c84 (more details on the Dr. CI page):

Expand to see more

💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@federicopozzi33
Copy link
Copy Markdown
Contributor Author

federicopozzi33 commented Aug 3, 2022

@datumbox

First draft. Documentation should be quite ok, but building documentation phase is missing.

Tests can be improved (combined versions are missing, e.g. PolynomialLR + StepLR).

I tried to generate the .pyi for optimizer.py but the result I obtained is quite different from the existing version. This is what I did, using mypy:

stubgen torch/optim/lr_scheduler.py

Generated file:

class PolynomialLR(_LRScheduler):
    total_iters: Incomplete
    min_lrs: Incomplete
    power: Incomplete
    def __init__(self, optimizer, total_iters: int = ..., min_lr: float = ..., power: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ...
    def get_lr(self): ...

class LinearLR(_LRScheduler):
    start_factor: Incomplete
    end_factor: Incomplete
    total_iters: Incomplete
    def __init__(self, optimizer, start_factor=..., end_factor: float = ..., total_iters: int = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ...
    def get_lr(self): ...

Existing optimizer.pyi:

class LinearLR(_LRScheduler):
    start_factor: float = ...
    end_factor: float = ...
    total_iters: int = ...
    def __init__(self, optimizer: Optimizer, start_factor: float=..., end_factor: float= ..., total_iters: int= ..., last_epoch: int= ..., verbose: bool = ...) -> None: ...

Note2: I confused torchvision issue with torch one, so the issue number in the branch name is wrong. Should I delete the branch?

Copy link
Copy Markdown
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@federicopozzi33 Thanks a lot for the PR. As discussed at pytorch/vision#4438 this is definitely something that TorchVision can use.

I've added a few comments but overall it looks good. Please let me know what you think.

Documentation should be quite ok, but building documentation phase is missing

I had to manually approve running your tests because that was your first PR. It's running the jobs now.

I confused torchvision issue with torch one, so the issue number in the branch name is wrong. Should I delete the branch?

I don't think PyTorch core has such a strict process for the naming conventions on your repo/branch. If renaming is needed we can do after the reviews.

I tried to generate the .pyi for optimizer.py but the result I obtained is quite different from the existing version.

I'm not sure about this one, let's leave it last to fix. Once the PR is ready, I'll ping some Core devs who could clarify these for us. :)

@federicopozzi33 federicopozzi33 force-pushed the feature/4438-poly-lr-scheduler branch 3 times, most recently from 5d44650 to a769f92 Compare August 6, 2022 00:31
Copy link
Copy Markdown
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think overall it looks good. I've added only one minor nit comment below.

The question around the generation of optimizer.pyi remains, but we can ask core devs how to handle it. I've restarted the unit-tests so that we can get a signal from the CI. If all tests pass, from my perspective, we should mark the PR as non draft and ask for the review of Alban/Joel.

@datumbox datumbox marked this pull request as ready for review August 8, 2022 15:19
@datumbox datumbox requested a review from albanD as a code owner August 8, 2022 15:19
@federicopozzi33
Copy link
Copy Markdown
Contributor Author

I think overall it looks good. I've added only one minor nit comment below.

The question around the generation of optimizer.pyi remains, but we can ask core devs how to handle it. I've restarted the unit-tests so that we can get a signal from the CI. If all tests pass, from my perspective, we should mark the PR as non draft and ask for the review of Alban/Joel.

Some tests are still failing, but it seems that they are not related to my changes... can you check you too, please?

@datumbox datumbox requested a review from jbschlosser August 8, 2022 15:20
@datumbox
Copy link
Copy Markdown
Contributor

datumbox commented Aug 8, 2022

@federicopozzi33 LGTM! The failing tests are unrelated.

I think now we should be good to ask for the feedback of @albanD and @jbschlosser. Please let us know your thoughts and your recommendation for the generation of optimizer.pyi.

Copy link
Copy Markdown
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small typo in the doc. but LGTM otherwise.

Thanks @datumbox for doing the early review!

@albanD
Copy link
Copy Markdown
Collaborator

albanD commented Aug 9, 2022

The .pyi files are written by hand if they are checked into the git repo. That's why it doesn't match what mypy would generate.

@federicopozzi33 federicopozzi33 requested a review from albanD August 10, 2022 12:16
Copy link
Copy Markdown
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @federicopozzi33.

I'll leave it to @albanD and @jbschlosser to approve and merge. Note that I spoke with them and they told me it's quite hectic for them this week, so we might need to be a bit patient. Don't worry I'll follow up with them regularly to get this over the line.

@albanD
Copy link
Copy Markdown
Collaborator

albanD commented Aug 10, 2022

@pytorchbot rebase

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Successfully rebased feature/4438-poly-lr-scheduler onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout feature/4438-poly-lr-scheduler && git pull --rebase)

@albanD
Copy link
Copy Markdown
Collaborator

albanD commented Aug 10, 2022

@pytorchbot merge -g

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Copy link
Copy Markdown
Contributor

Hey @federicopozzi33.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Aug 11, 2022
Summary:
### Description
<!-- What did you change and why was it needed? -->

Add PolynomialLR scheduler.

### Issue
Closes #79511.

### Testing
I added tests for PolynomialLR.

Pull Request resolved: #82769
Approved by: https://github.com/datumbox

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f8a10a7f79b574422dc5477e86284c7539790bde

Reviewed By: seemethere

Differential Revision: D38600088

Pulled By: seemethere

fbshipit-source-id: 6f93f47efda4072284c3049f21df0f70690100fe
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Suggest for a polynominal lr_scheduler

6 participants