
Parametrizations depending on several inputs #58488

Closed
lezcano wants to merge 12 commits into pytorch:master from Quansight:lezcano/multiparam

Conversation

@lezcano
Collaborator

@lezcano lezcano commented May 18, 2021

Makes it possible for the first registered parametrization to depend on a number of parameters rather than just one. Examples of this type of parametrization are torch.nn.utils.weight_norm and low-rank parametrizations via the multiplication of an n x k tensor by a k x m tensor with k <= m, n.

Follows the plan outlined in #33344 (comment). A short summary of the idea: we call right_inverse when registering a parametrization to generate the tensors that we are going to save. If right_inverse returns a sequence of tensors, then we save them as original0, original1... If it returns a Tensor or a sequence of length 1, we save it as original.

Many-to-one parametrizations are only allowed as the first parametrization registered; any subsequent parametrizations must be one-to-one.
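For illustration, here is a minimal sketch of such a many-to-one parametrization (the LowRank module and its torch.svd_lowrank-based right_inverse are illustrative choices, not code from this PR):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LowRank(nn.Module):
    def __init__(self, rank):
        super().__init__()
        self.rank = rank

    def forward(self, U, V):
        # Rebuild the full n x m weight from an n x k and a k x m factor
        return U @ V

    def right_inverse(self, W):
        # Factor the unparametrized weight into two tensors; per this PR,
        # they get saved on the module as original0 and original1
        U, S, V = torch.svd_lowrank(W, q=self.rank)
        return U * S, V.transpose(-2, -1)

layer = nn.Linear(5, 5)
parametrize.register_parametrization(layer, "weight", LowRank(rank=2))
print(layer.parametrizations.weight.original0.shape)  # torch.Size([5, 2])
print(layer.parametrizations.weight.original1.shape)  # torch.Size([2, 5])
```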

There were a number of choices in the implementation:

If right_inverse returns a sequence of parameters, then we unpack it in the forward. This allows writing code such as:

class Sum(nn.Module):
  def forward(self, X, Y):
    return X + Y
  def right_inverse(self, Z):
    return Z, torch.zeros_like(Z)

rather than having to manually unpack a list or a tuple within the forward function.
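For instance, a small usage sketch (not from the PR itself; it assumes the Sum module above and the original0/original1 naming convention described earlier):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

linear = nn.Linear(3, 3)
parametrize.register_parametrization(linear, "weight", Sum())  # Sum as defined above

params = linear.parametrizations.weight
print(hasattr(params, "original0"), hasattr(params, "original1"))  # True True
# weight is recomputed as forward(original0, original1) on access
assert torch.allclose(linear.weight, params.original0 + params.original1)
```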

At the moment the errors are a bit all over the place. This is to avoid having to check some properties of forward and right_inverse when they are registered. I have left it like this for now, but I believe it would be better to call these functions at registration time to make sure the invariants hold and throw errors as soon as possible.

The invariants are the following:

  1. The following code should be well-formed:
X = module.weight
Y = param.right_inverse(X)
assert isinstance(Y, Tensor) or isinstance(Y, collections.abc.Sequence)
Z = param(Y) if isinstance(Y, Tensor) else param(*Y)

In other words, if Y is a Sequence of Tensors (we also check that the elements of the sequence are Tensors), then it has the same length as the number of parameters that param.forward accepts.

  2. Always: X.dtype == Z.dtype and X.shape == Z.shape. This is to protect users from shooting themselves in the foot, as it would be too odd for a parametrization to change the metadata of a tensor.
  3. If it's one-to-one: X.dtype == Y.dtype. This is so we can do X.set_(Y): if a user first instantiates the optimiser and then registers the parametrisation, we reuse X and the user does not need to add a new parameter to the optimiser (a sketch of this reuse follows below). Alas, this is not possible when the parametrisation is many-to-one. The current implementations of spectral_norm and weight_norm do not seem to care about this, so this is not a regression. I left a warning in the documentation though, as this case is a bit tricky.
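A sketch of the set_ reuse in invariant 3 (the Symmetric module is illustrative, and the final assertion holds only if the reuse works as described above):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-2, -1)

    def right_inverse(self, Z):
        return Z.triu()  # same dtype as Z, so X.set_(Y) is possible

layer = nn.Linear(4, 4)
w = layer.weight                                   # Parameter held by the optimiser
opt = torch.optim.SGD(layer.parameters(), lr=0.1)  # optimiser instantiated first
parametrize.register_parametrization(layer, "weight", Symmetric())
# The saved tensor should still be the very same Parameter object,
# so opt keeps tracking it without any extra bookkeeping
assert layer.parametrizations.weight.original is w
```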

I still need to go over the formatting of the documentation; I'll do that tomorrow.

@lezcano lezcano added the module: nn Related to torch.nn label May 18, 2021
@lezcano lezcano requested review from albanD and soulitzer May 18, 2021 16:17
@lezcano lezcano requested a review from jbschlosser as a code owner May 18, 2021 16:17
@facebook-github-bot
Contributor

facebook-github-bot commented May 18, 2021

💊 CI failures summary and remediations

As of commit ccc0a87 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

ci.pytorch.org: 1 failed



@codecov

codecov Bot commented May 18, 2021

Codecov Report

Merging #58488 (1019503) into master (821a975) will increase coverage by 0.00%.
The diff coverage is 91.04%.

❗ Current head 1019503 differs from pull request most recent head 31a1bcd. Consider uploading reports for the commit 31a1bcd to get more accurate results

@@           Coverage Diff           @@
##           master   #58488   +/-   ##
=======================================
  Coverage   76.46%   76.46%           
=======================================
  Files        1992     1992           
  Lines      199937   199980   +43     
=======================================
+ Hits       152879   152913   +34     
- Misses      47058    47067    +9     

Collaborator

@albanD albanD left a comment


From a quick look, it looks quite good!

I think you can simplify a lot of code by making the params always a sequence as mentioned below.

Also, I'm not sure I understand what prevents you from chaining two parametrizations that have the following forwards:

# param 1
def forward(self, p1, p2):
    return p1, p2
    
# param 2
def forward(self, p1, p2):
    return p1 + p2

Comment thread test/test_nn.py Outdated
Collaborator

Do you have a code sample that reproduces this outside of reparametrization?
Is it because t.weight is contiguous and t.weight.grad is also contiguous,
but then you do t.weight.set_(other) where other is not contiguous while t.weight.grad still is?

Comment thread test/test_nn.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

A common trick we use here is to get value to always be a sequence by wrapping single Tensors into a tuple of size 1 (see the sketch after this list).

  • That simplifies the logic here, as you don't need to special-case a single Tensor
  • You can always handle original0 instead of original when there is a single Tensor
  • Calling into the forward can be done with common code as well
  • If you really need to unpack it into a single Tensor (for the use of set_, for example), you can always save a boolean along with the sequence
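A minimal sketch of the suggested normalization (the names are illustrative, not from the PR):

```python
from typing import Sequence, Tuple, Union

import torch
from torch import Tensor

def as_tuple(value: Union[Tensor, Sequence[Tensor]]) -> Tuple[Tuple[Tensor, ...], bool]:
    # Wrap a lone Tensor into a 1-tuple so downstream code always sees a
    # sequence; the boolean records whether the input was a single Tensor
    # (needed later to decide whether set_ can be used)
    if isinstance(value, Tensor):
        return (value,), True
    return tuple(value), False

tensors, was_single = as_tuple(torch.zeros(2))
assert was_single and len(tensors) == 1
```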

Collaborator Author

I tried doing that, but the treatment of these two cases is actually different enough that it's not worth it---it clutters the implementation too much.

This comes from the fact that we reuse the tensor to avoid changing the id of the parameter and so on. Note, for example, that the exceptions raised in the if-else below are fundamentally different.

@lezcano
Collaborator Author

lezcano commented May 19, 2021

From an offline discussion: it's not possible to register parametrizations with several outputs, as each of these outputs would need to be a property on its own.

@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 20, 2021
@lezcano lezcano force-pushed the lezcano/multiparam branch 2 times, most recently from 806919e to e8e586f Compare June 7, 2021 14:04
@lezcano
Collaborator Author

lezcano commented Jun 7, 2021

@albanD this is ready for another round of reviews.

Now the errors are all detected and thrown when registering a new parametrization. This is very good, as we test that the forward and the right_inverse can be run.

There used to be a very annoying problem: if you had made a mistake when writing your parametrization and accessed self.foo where foo did not exist, register_parametrization would succeed. But then, when you called module.weight, since parametrizations are registered as properties, the parametrization would raise an AttributeError. This error would be captured by Module.__getattr__, and the final error you got was "module does not have an attribute called 'weight'", which was very confusing.

This early testing should get rid of this problem and, in general, should error early, which is nice.
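For illustration, a sketch of the failure mode (the Buggy module is made up; assuming the eager checks run forward at registration time, the AttributeError surfaces immediately instead of being swallowed by Module.__getattr__ on weight access):

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Buggy(nn.Module):
    def forward(self, X):
        return X * self.foo  # bug: self.foo was never defined

layer = nn.Linear(3, 3)
try:
    parametrize.register_parametrization(layer, "weight", Buggy())
except AttributeError as e:
    print("caught at registration:", e)
```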

Collaborator

@albanD albanD left a comment


You committed some submodule updates by mistake, I think.

Looks good to me. Mainly some small naming updates.

Comment thread test/test_nn.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

nit: a sequence

Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

nit: why not name them original and new in your code sample?

Collaborator Author

It was more difficult to track all the names, and I was not able to come up with meaningful ones. X, Y, Z made for more concise lines.

Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

I don't think we want to enforce the shape equality for X and Y. In particular, set_ will work just fine with another Tensor of a different size.

Collaborator Author

It is not necessary, but I thought that, in the same way that we disallow changing the dtype so that users don't encounter weird behaviours, we should also disallow changing the shape. That being said, I can remove this constraint.

Collaborator Author

On second thought, let's not be too conservative. I'm removing this constraint, as I could imagine it being useful.

Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

I think you would prefer to exactly match the two properties of the original: requires_grad, and whether or not it is a Parameter.
Users can set requires_grad=False on Parameters, or create attributes that require gradients but are not Parameters.

Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

I'm not sure we can make that assumption here.
You most likely want to double check things here, especially below for the number of Tensors it returns when a Sequence is involved.

Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

Why does it have to be?


Collaborator

For the one above, I understand that you don't want to use original/new in the sample to keep names short.
But I'm not sure why X and Z are switched here. It would be better to have consistent naming throughout.

@lezcano
Collaborator Author

lezcano commented Jun 9, 2021

This is ready for another review @albanD

The invariants preserved and tested for are the following:

  • model.weight should never change dtype or shape. In other words, the parametrization.forward of every parametrisation registered on a tensor should return a tensor of the same dtype and shape as the original unparametrised tensor.
  • If registered on an unparametrised tensor, right_inverse should return a tensor or a sequence of tensors.
    • If it returns one tensor, the returned tensor should have the same dtype as the original tensor (to be able to use set_).
  • Otherwise (i.e. if registered on a parametrised tensor), right_inverse should return a tensor of the same shape and dtype as model.weight (sketched below).

I believe that this is a minimal way of describing the set of invariants that we need.

I hope I have not missed any of them in the testing or in the exceptions in the code. It was quite tricky to convince myself that everything was correct...
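A sketch of the last invariant, chaining a one-to-one parametrization on an already-parametrized weight (both modules are illustrative, not from the PR):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Triu(nn.Module):
    def forward(self, X):
        return X.triu()

    def right_inverse(self, Z):
        return Z.triu()

class Scale(nn.Module):
    def forward(self, X):
        return 2.0 * X

    def right_inverse(self, Z):
        # same shape and dtype as model.weight, as the invariant requires
        return Z / 2.0

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Triu())
parametrize.register_parametrization(layer, "weight", Scale())  # one-to-one on top
print(layer.weight.shape, layer.weight.dtype)  # torch.Size([4, 4]) torch.float32
```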

@lezcano lezcano force-pushed the lezcano/multiparam branch from d0b0b26 to ba56b21 Compare June 11, 2021 11:06
Collaborator

@albanD albanD left a comment


LGTM
only small nits

Comment thread torch/nn/utils/parametrize.py Outdated
raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
f"Got {type(new).__name__}")

# If it is a sequence of one tensor, we unpack it
Collaborator

Do we really want to do that? Also, this is not documented?

Collaborator Author

The idea was that in the 1 input case we can reuse the original tensor, which is good. Now, this is such a weird case... What do you think we should do?

Collaborator

If people return (foo,) instead of foo, I'm sure they have a good reason to do it, so I would lean towards removing this.
Also, the way it works right now, you would pass foo to their forward and not (foo,), which is not consistent.

Comment thread torch/nn/utils/parametrize.py Outdated
"""
# All the exceptions in this function should almost never throw.
# They could throw if, for example, right_inverse function does not return the same dtype
# for every input, which should most likely be caused by a bug in the code
Collaborator

nit: maybe this could be clearer if we say that "right_inverse function returns a different dtype when given a different input"

Comment thread torch/nn/utils/parametrize.py
Use collections.abc
Change the handling of a sequence of one tensor
Reword some bits
@lezcano
Collaborator Author

lezcano commented Jun 11, 2021

@albanD corrected!
I also added a test for the case of a sequence of 1 tensor.

Comment thread torch/nn/utils/parametrize.py Outdated
def forward(self) -> Tensor:
# Unpack the originals for the first parametrization
if self.ntensors == 1:
if self.is_tensor == 1:
Collaborator

These are bool checks, no? You don't want the == 1?

Collaborator Author

Oooopsss. Sorry for that and thanks for the catch!
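For reference, a self-contained sketch of the corrected check (the class is a stand-in for the real ParametrizationList, whose surrounding code is elided here):

```python
class ParametrizationListSketch:
    def __init__(self, ntensors: int, is_tensor: bool):
        self.ntensors = ntensors
        self.is_tensor = is_tensor

    def describe(self) -> str:
        # Unpack the originals for the first parametrization
        if self.ntensors == 1:
            if self.is_tensor:  # was: `if self.is_tensor == 1`
                return "single original tensor"
            return "sequence of length 1"
        return "several originals"

print(ParametrizationListSketch(1, True).describe())  # single original tensor
```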

@facebook-github-bot
Contributor

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@albanD merged this pull request in 061e71b.

Comment thread test/test_nn.py
weight_data = weight.data.clone()
with torch.no_grad():
weight.set_(weight_data)
weight.copy_(weight_data)
Contributor

this is the only place I can think of that relates to the HUD CI failure

@albanD
Collaborator

albanD commented Jun 14, 2021

This seems to be breaking test_cudnn_weight_format in master: https://app.circleci.com/pipelines/github/pytorch/pytorch/336608/workflows/717d1325-5ae0-4b73-abf8-bced763e9dbb/jobs/14137264
I have no idea how that could be, but reverting for now.

@facebook-github-bot
Contributor

This pull request has been reverted by 5c1d17e.

facebook-github-bot pushed a commit that referenced this pull request Jun 25, 2021
Summary:
Resubmit of #58488

There was a line that had been changed in `test_nn.py` as caught in #58488 (comment)

I reverted that line, which should never have been changed. I reckon that should solve the issue.

Pull Request resolved: #60530

Reviewed By: ngimel

Differential Revision: D29329865

Pulled By: albanD

fbshipit-source-id: 8dfd0cd968fe26a3924dae7ca366af2c8a8639b3
asuhan pushed a commit to asuhan/pytorch that referenced this pull request Jun 28, 2021
asuhan pushed a commit that referenced this pull request Jun 30, 2021
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026

Labels

cla signed · Merged · module: nn (Related to torch.nn) · open source · Reverted · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


6 participants