
WIP: feat: LARS optimizer #88106

Draft
federicopozzi33 wants to merge 7 commits into pytorch:main from federicopozzi33:feature/lars-optimizer

Conversation

@federicopozzi33
Contributor

@federicopozzi33 federicopozzi33 commented Oct 31, 2022

Follow-up to #6323.

Addition of LARS optimizer.

  • LARS optimizer
  • Tests
  • Documentation
  • Multi-Tensor support
  • Extra params (e.g., maximize, differentiable, foreach)
  • .pyi

Reference implementations: [1]

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @desertfire
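For reviewers skimming the thread, the per-layer update being implemented can be sketched roughly as follows (an illustrative single-tensor sketch following Algorithm 1 of the paper, not this PR's actual code; all names and defaults here are assumptions):

```python
import torch

def lars_step(param, grad, momentum_buf, lr=1.0, momentum=0.9,
              weight_decay=1e-4, trust_coefficient=0.001, eps=1e-8):
    # Layer-wise "trust ratio": scale the step by ||w|| / (||g|| + wd * ||w||).
    p_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    lars_lr = trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)
    d_p = grad.add(param, alpha=weight_decay)      # add weight decay to the gradient
    momentum_buf.mul_(momentum).add_(d_p, alpha=lr * lars_lr)
    param.sub_(momentum_buf)

w = torch.ones(4)
g = torch.full((4,), 0.5)
buf = torch.zeros(4)
lars_step(w, g, buf)   # w moves slightly below 1.0
```

The open review questions below (eps handling, the norm check, in-place updates) all concern the middle three lines of this sketch.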

@pytorch-bot

pytorch-bot bot commented Oct 31, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88106

Note: Links to docs will display an error until the docs builds have been completed.

❌ 15 New Failures, 5 Unrelated Failures

As of commit c71c362 with merge base 64077ce:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Oct 31, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

@federicopozzi33 federicopozzi33 force-pushed the feature/lars-optimizer branch 3 times, most recently from 81c21fc to 4894875 on November 6, 2022 10:00
@federicopozzi33
Contributor Author

Hi @datumbox,
I think you can start reviewing what I've done so far.

@datumbox
Contributor

@federicopozzi33 Thanks for the ping! I'm a bit swamped this week. Can I get back to you by the middle of next one?

Contributor

@datumbox datumbox left a comment


Thanks for the work @federicopozzi33. I've added a few comments/questions below, let me know your thoughts. Also is there a reference implementation we should be crediting here?

@frgfm I was wondering if you'd have the chance to also take a look, given you've previously implemented it, to help ensure its validity.

Contributor

Nit: Why not set a default value similar to other optimizers?

Contributor

I agree, 1e-3 is being used on the others 👍

Contributor Author

@federicopozzi33 federicopozzi33 Dec 10, 2022

Do you mean to set a default value for the learning rate? I didn't because SGD, Adam, and other optimizers don't have one.

Contributor

Adam has lr=1e-3 as @frgfm mentioned. I agree it would be good to have a default here.

Contributor

func is not declared if scripting. Is this implementation complete?

Contributor Author

@federicopozzi33 federicopozzi33 Dec 10, 2022

Only the single_tensor implementation is complete. The multi_tensor one is still missing.

I'm temporarily raising an exception when scripting.

EDIT: I looked at the SGD implementation more carefully, and it seems that scripting combined with foreach is simply not supported. So I think the original implementation was OK.

Comment on lines 45 to 46
Contributor

Nit: why not put the trust_coefficient and eps in the defaults above?

Contributor Author

Fixed.

Comment on lines 85 to 86
Contributor

Does it make sense to grab these from the group similar to other params here?

Contributor Author

Yes!

@frgfm
Contributor

frgfm commented Nov 28, 2022

> @frgfm I was wondering if you'd have the chance to also take a look, given you've previously implemented it, to help ensure its validity.

Sure, I'll take a look by tomorrow!

Contributor

@frgfm frgfm left a comment

Thanks for the PR 🙏 I understand it's a work in progress, so don't treat my comments as final!

I'll need to take a look at the latest implementation of SGD with the single tensor & multi tensor functional API to make a comprehensive review. Let me know what you think

Contributor

I agree, 1e-3 is being used on the others 👍

Comment on lines 16 to 24
Contributor

I suggest keeping the same arg order as other optimizers as well

Contributor Author

@federicopozzi33 federicopozzi33 Dec 10, 2022

I followed SGD's order; Adam's is quite different (due to different parameters too).

Comment on lines 63 to 75
Contributor

Open suggestion but perhaps list comprehensions would be better?

Contributor Author

@federicopozzi33 federicopozzi33 Dec 10, 2022

With a list comprehension I would need to loop more than once; this way I loop only once, which is more efficient.
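The trade-off being discussed can be shown in isolation (a purely illustrative sketch with placeholder data, not the PR's code): one loop fills several lists in a single traversal, whereas comprehensions traverse the params once per list.

```python
# Hypothetical stand-in for a param group: each entry has a grad and some state.
params = [{"grad": i, "state": i * 10} for i in range(3)]

# Single loop: one traversal, appending to both lists.
grads, states = [], []
for p in params:
    grads.append(p["grad"])
    states.append(p["state"])

# Equivalent comprehensions: arguably clearer, but two traversals.
grads2 = [p["grad"] for p in params]
states2 = [p["state"] for p in params]

assert grads == grads2 and states == states2
```

For a handful of parameter lists the difference is small either way, so this is mostly a readability call.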

Comment on lines 90 to 93
Contributor

For readability, I suggest moving that before the functional API call

Contributor Author

@federicopozzi33 federicopozzi33 Dec 10, 2022

I don't think it would be the same: in the functional call, momentum_buffer_list is updated in place.

Contributor

Agreed it would be after

Contributor

To the best of my understanding, the global LR is not applied to the momentum part.

Contributor Author

@federicopozzi33 federicopozzi33 Dec 10, 2022

Thank you, I will check the paper and some online implementations again soon.

Comment on lines 151 to 156
Contributor

Not sure the eps is required: if p_norm is zero, then the whole thing is zero; if g_norm is zero, then the basic update term is zero (same effect as using local LR = 1 if either of those is zero, but we avoid the eps imprecision :))

Contributor Author

@federicopozzi33 federicopozzi33 Dec 10, 2022

I think it's required to avoid zero division when both terms are zero.

Contributor

p_norm and g_norm are banned from being 0 with the conditional, no?

That said, I am not sure why lightning included an eps.

Contributor

Suggested change
-if p_norm * g_norm > 0:
-    lars_lr = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + eps)
+lars_lr = trust_coefficient * p_norm / max(g_norm + p_norm * weight_decay, eps)

This would avoid unnecessarily biasing the denominator.
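The difference between the two guards is easy to see numerically (illustrative values, reusing the variable names from the snippet above):

```python
trust_coefficient, weight_decay, eps = 0.001, 0.0, 1e-8
p_norm, g_norm = 1.0, 1.0

# Additive eps always perturbs the denominator slightly...
biased = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + eps)

# ...while max() leaves it untouched whenever it is already >= eps.
unbiased = trust_coefficient * p_norm / max(g_norm + p_norm * weight_decay, eps)

assert biased < trust_coefficient      # bias introduced even in the benign case
assert unbiased == trust_coefficient   # exact local LR recovered
```

Both forms still protect against a zero denominator; only the non-degenerate case differs.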

Contributor Author

@federicopozzi33 federicopozzi33 May 5, 2023

I checked the code again, and I think you're right.

@federicopozzi33 federicopozzi33 requested review from datumbox and frgfm and removed request for datumbox and frgfm December 10, 2022 18:16
@datumbox
Contributor

@federicopozzi33 there are still some conflicts; could you resolve them?

@github-actions github-actions bot added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Dec 18, 2022
group.setdefault("maximize", False)
group.setdefault("differentiable", False)

def _init_group(self, group, params_with_grad, grads, momentum_buffer_list):
Contributor

I notice that has_sparse_grad is excluded from this implementation so far. What are your plans with the variable?

Wanted to give you the heads up that we are planning on deprecating has_sparse_grad, because our main use of it is to determine whether we could use the foreach implementation or not for it. Instead of barring all params from going through the foreach implementation because of one sparse grad, the new ideal is to group the tensors by device and to put everything that cannot be foreach'd into the cpu bucket, so that we can maximally enjoy the speed of foreach.
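The bucketing idea described above can be sketched in plain Python (a rough sketch of the intent using dicts as stand-ins for tensors, not PyTorch's actual grouping helper):

```python
from collections import defaultdict

def group_for_foreach(tensors):
    """Hypothetical sketch: bucket tensors so that dense, same-device groups
    can go through fast foreach kernels, while anything that cannot be
    foreach'd (e.g. sparse grads) falls into a fallback bucket."""
    buckets = defaultdict(list)
    for t in tensors:
        key = ("fallback",) if t["is_sparse"] else (t["device"], t["dtype"])
        buckets[key].append(t["name"])
    return dict(buckets)

tensors = [
    {"name": "w1", "device": "cuda:0", "dtype": "float32", "is_sparse": False},
    {"name": "w2", "device": "cuda:0", "dtype": "float32", "is_sparse": False},
    {"name": "emb", "device": "cuda:0", "dtype": "float32", "is_sparse": True},
]
groups = group_for_foreach(tensors)
# w1 and w2 share a foreach bucket; emb goes to the fallback path alone.
```

This is why a single global `has_sparse_grad` flag becomes unnecessary: the decision moves into per-bucket dispatch.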

Contributor Author

@federicopozzi33 federicopozzi33 May 5, 2023

To be honest, I skipped it. My idea was to complete at least the basic implementation, then gradually implement more "advanced" parameters, like foreach.

Does it sound good to you?

Contributor

yes, I wanted to give you more context regarding the deprecation of has_sparse_grad (as we plan to remove it for all optims as well)

Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
Contributor

Add default 1e-3 here

Contributor

Unless with LARS it is recommended to decide on an initial LR? From the paper, it looks like they specifically use bigger LRs.

Contributor Author

@federicopozzi33 federicopozzi33 May 5, 2023

Yes, they experiment with much bigger LRs in the paper. That's why I didn't initially set a default value.

In one reference implementation (here), LR=1.0 is set, whereas another (here) sets no default value.

Contributor

Ah, maybe we should just up the default to 1.0, with a comment that larger LRs are used with LARS
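If that route were taken, the defaults handling might look like this (a hypothetical sketch, mirroring SGD's argument order as discussed above; `lars_init` and every default value here are assumptions, not the merged API):

```python
def lars_init(params, lr=1.0, momentum=0.0, dampening=0.0,
              weight_decay=0.0, trust_coefficient=0.001, eps=1e-8):
    # Note: LARS is typically run with much larger base LRs than SGD/Adam,
    # hence a default of 1.0 rather than the usual 1e-3.
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    return dict(lr=lr, momentum=momentum, dampening=dampening,
                weight_decay=weight_decay,
                trust_coefficient=trust_coefficient, eps=eps)

defaults = lars_init([], lr=1.0)
```

A comment next to the default, as suggested, would keep users from blindly copying SGD-style LRs.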


LARS.__doc__ = r"""Implements LARS algorithm.

For further details regarding the algorithm we refer to `Large Batch Training of Convolutional Networks`_.
Contributor

We would want to include the LaTeX version of the algorithm here. I just took some time to read the arxiv linked for the LARS introduction. Is it just me, or are there several key typos in their algorithm (e.g., the trust coefficient should be used in the local lr update but isn't)?

Contributor Author

I noticed it too.

Do you think I should just copy the algorithm they reported in the paper?
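For what it's worth, following the paper's Algorithm 1 with the trust coefficient restored in the local-LR line, the docstring's LaTeX might sketch the per-layer update like this (my reading of the paper, to be double-checked against it):

```latex
\begin{aligned}
  \lambda^{(l)} &= \eta \,
    \frac{\lVert w^{(l)}_t \rVert}
         {\lVert \nabla L(w^{(l)}_t) \rVert + \beta \lVert w^{(l)}_t \rVert} \\
  v^{(l)}_{t+1} &= m \, v^{(l)}_t
    + \gamma_t \, \lambda^{(l)} \left( \nabla L(w^{(l)}_t) + \beta \, w^{(l)}_t \right) \\
  w^{(l)}_{t+1} &= w^{(l)}_t - v^{(l)}_{t+1}
\end{aligned}
```

Here \(\eta\) is the trust coefficient, \(\beta\) the weight decay, \(m\) the momentum, and \(\gamma_t\) the global LR; note the global LR scales the new gradient term but not the accumulated momentum, matching the earlier review comment.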

Comment on lines +192 to +193
p_norm = torch.norm(param.data)
g_norm = torch.norm(d_p.data)
Contributor

Suggested change
-p_norm = torch.norm(param.data)
-g_norm = torch.norm(d_p.data)
+p_norm = torch.norm(param)
+g_norm = torch.norm(d_p)

.data is deprecated + should not be used
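The same value is available without touching `.data`, since optimizer steps already run under `torch.no_grad()` (a small illustrative check, not the PR's code):

```python
import torch

p = torch.ones(3, requires_grad=True)

with torch.no_grad():           # optimizer steps already run in this context
    n = torch.norm(p)

legacy = torch.norm(p.data)     # deprecated style, same numeric result

assert torch.equal(n, legacy)   # identical values, no autograd bypass needed
```

So dropping `.data` changes nothing numerically; it just avoids the deprecated autograd escape hatch.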


if weight_decay != 0:
# LARS scaling:
if p_norm * g_norm > 0:
Contributor

Suggested change
-if p_norm * g_norm > 0:
+if p_norm != 0 and g_norm != 0:

This should be a cheaper check



Contributor

@janeyx99 janeyx99 May 1, 2023

Actually the check is probably unnecessary (reason being that they'll be floats and comparing floats to 0 is often silly). I would vouch for getting rid of this check entirely and keeping the eps term.

Contributor Author

I moved the two norm computations inside the if condition, since they are used only there.

lars_lr = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + eps)

d_p = d_p.add(param, alpha=weight_decay)
d_p.mul_(lars_lr)
Contributor

Suggested change
-d_p.mul_(lars_lr)
+d_p = d_p * lars_lr

avoid inplace updates

Contributor Author

@federicopozzi33 federicopozzi33 May 5, 2023

Ok. There are some other in-place updates. Should I change them too?

I'm referring to:

buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

and

param.add_(d_p, alpha=-lr)
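The aliasing risk behind the suggestion can be shown in a few lines (a standalone sketch; in the real step, `d_p` can alias the user-visible gradient when the weight-decay branch is skipped):

```python
import torch

grad = torch.full((3,), 0.5)
d_p = grad                       # no weight-decay branch: d_p aliases the gradient

d_p.mul_(0.1)                    # in-place scale silently rewrites `grad` too
aliased_value = grad[0].item()   # no longer the original 0.5

grad2 = torch.full((3,), 0.5)
d_p2 = grad2 * 0.1               # out-of-place: the gradient is left intact
intact_value = grad2[0].item()   # still 0.5
```

By contrast, `buf.mul_(momentum)` and `param.add_(d_p, alpha=-lr)` mutate the optimizer's own state and the parameters on purpose, so those in-place calls are expected.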

optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)

def test_lars(self):
# ASK: What's the reason behind two identical calls? (See SGD tests)
Contributor

@janeyx99 janeyx99 May 1, 2023

Which two are you referring to here? The optim tests are certainly up for a refactor, so there might just be actual redundancies we can get rid of

Contributor Author

@federicopozzi33 federicopozzi33 May 5, 2023

I'm referring to the first two test cases defined here

def test_sgd(self):

@janeyx99
Contributor

janeyx99 commented May 1, 2023

@federicopozzi33 I've gone over it with a more thorough review. Generally, the math looks consistent :D

I've also read the paper and realized a bit anticlimactically that the entire "layer-wise" portion of the optimizer is just a local multiplier that seeks to balance out the weight norm to grad norm ratio. Thankfully that means that this optimizer isn't too crazy to implement. 😛

@federicopozzi33
Contributor Author

federicopozzi33 commented May 2, 2023

> @federicopozzi33 I've gone over it with a more thorough review. Generally, the math looks consistent :D
>
> I've also read the paper and realized a bit anticlimactically that the entire "layer-wise" portion of the optimizer is just a local multiplier that seeks to balance out the weight norm to grad norm ratio. Thankfully that means that this optimizer isn't too crazy to implement. 😛

Hi @janeyx99,
thanks for the review. I will go through your comments in the next few days.

@federicopozzi33 federicopozzi33 requested a review from janeyx99 May 21, 2023 16:25
@janeyx99
Contributor

@federicopozzi33 I wanted to update you on our side--sorry for the delay. We are currently in the middle of planning and discussing how we want to offer and incorporate new optimizers + optimizer features, including a test revamp that should enable safer test coverage for new optims. Thus, let's put a pause on this PR until we reach a consensus on conclusions from our side.

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 27, 2023
@github-actions github-actions bot closed this Sep 26, 2023
@datumbox
Contributor

@janeyx99 @federicopozzi33 Too bad to see this optimiser didn't make it into PyTorch, despite its popularity. Any chance we can kick this off again?

@federicopozzi33
Contributor Author

federicopozzi33 commented Oct 1, 2023

Hi @datumbox,
I'm waiting for instructions from @janeyx99 to get back to it.

See the latest messages for more details.

@janeyx99 janeyx99 reopened this Oct 1, 2023
@janeyx99 janeyx99 removed the Stale label Oct 1, 2023
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Nov 30, 2023
@janeyx99 janeyx99 removed the Stale label Nov 30, 2023
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jan 30, 2024
@janeyx99 janeyx99 added no-stale and removed Stale labels Jan 30, 2024

Labels

ciflow/inductor module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration module: optimizer Related to torch.optim no-stale open source release notes: nn release notes category


5 participants