Parametrization Functionality#33344

Closed
lezcano wants to merge 2 commits into pytorch:master from lezcano:master

Conversation

@lezcano
Collaborator

@lezcano lezcano commented Feb 14, 2020

Provides the implementation for feature request issue #28937.

Adds the Parametrization functionality and implements Pruning on top of it [UPDATE: it doesn't implement Pruning].
It adds an auto mode, in which the parametrization is computed just once per forward pass. The previous implementation recomputed the pruning on every forward, which is not optimal when pruning RNNs, for example.

It implements a caching mechanism for parameters, through the mechanism proposed at the end of the discussion #7313. In particular, it assumes that the user will not manually change the updated parameters between the call to backward() and optimizer.step(). If they do, they need to manually call the .invalidate() function provided in the implementation. This could be made into a function that takes a model and invalidates all the parameters in it. It might be the case that this function has to be called in .cuda(), .to(), and related functions.
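As a pure-Python sketch of the invalidation contract described above (the class and method names here are illustrative, not the PR's actual API):

```python
class CachedParametrization:
    """Toy model of the caching scheme: the parametrized value is
    computed once and reused until explicitly invalidated."""

    def __init__(self, fn, raw):
        self.fn = fn      # the parametrization function
        self.raw = raw    # the underlying (original) parameter
        self._cache = None

    @property
    def value(self):
        if self._cache is None:
            self._cache = self.fn(self.raw)
        return self._cache

    def invalidate(self):
        # Must be called if `raw` was changed out-of-band
        # (e.g. manually, between backward() and optimizer.step()).
        self._cache = None


calls = []

def double(x):
    calls.append(x)
    return 2 * x

p = CachedParametrization(double, 3)
assert p.value == 6 and p.value == 6
assert calls == [3]        # computed only once thanks to the cache
p.raw = 5                  # out-of-band update: the cache is now stale
p.invalidate()
assert p.value == 10
assert calls == [3, 5]
```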

As described in #7313, this could be used to implement the weight_norm and spectral_norm functions in a cleaner way. It also allows, as described in #28937, for the implementation of constrained optimization on manifolds (i.e. orthogonal constraints, positive definite matrices, invertible matrices, weights on the sphere or the hyperbolic space...)

TODO (when implementation is validated):

  • More thorough tests
  • Documentation

Resolves #28937

@albanD

@lezcano lezcano requested a review from apaszke as a code owner February 14, 2020 14:32
@lezcano lezcano changed the title Parametrizations implemented with some minimal testing. Parametrization Functionality Feb 14, 2020
@dr-ci

dr-ci Bot commented Feb 14, 2020

💊 CI failures summary and remediations

As of commit ff0aed3 (more details on the Dr. CI page):



❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) ❄️

Mar 04 16:47:57 urllib.error.HTTPError: HTTP Error 403: Forbidden
Mar 04 16:47:57   File "/opt/conda/lib/python3.6/urllib/request.py", line 532, in open
Mar 04 16:47:57     response = meth(req, response)
Mar 04 16:47:57   File "/opt/conda/lib/python3.6/urllib/request.py", line 642, in http_response
Mar 04 16:47:57     'http', request, response, code, msg, hdrs)
Mar 04 16:47:57   File "/opt/conda/lib/python3.6/urllib/request.py", line 570, in error
Mar 04 16:47:57     return self._call_chain(*args)
Mar 04 16:47:57   File "/opt/conda/lib/python3.6/urllib/request.py", line 504, in _call_chain
Mar 04 16:47:57     result = func(*args)
Mar 04 16:47:57   File "/opt/conda/lib/python3.6/urllib/request.py", line 650, in http_error_default
Mar 04 16:47:57     raise HTTPError(req.full_url, code, msg, hdrs, fp)
Mar 04 16:47:57 urllib.error.HTTPError: HTTP Error 403: Forbidden
Mar 04 16:47:57 
Mar 04 16:47:57 ----------------------------------------------------------------------
Mar 04 16:47:57 Ran 1 test in 0.456s
Mar 04 16:47:57 
Mar 04 16:47:57 FAILED (errors=1)
Mar 04 16:47:57 Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist-data/MNIST/raw/train-images-idx3-ubyte.gz
Mar 04 16:47:58 
0it [00:00, ?it/s]
Mar 04 16:47:58 + cleanup
Mar 04 16:47:58 + retcode=1
Mar 04 16:47:58 + set +x

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

This comment was automatically generated by Dr. CI.

Comment thread torch/nn/modules/module.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

yes, otherwise the optimizer won't find the parameters to update. params in a torch.optim object are the IDs of the parameters.

Collaborator Author

This is a good point. As such, the implemented behaviour of not allowing leave_parametrized=True whenever the size of the parametrized tensor differs from that of the original tensor is quite reasonable.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

is this ever None? Isn't it initialized to an empty OrderedDict?

Collaborator Author

I think you're right. I'll delete this check.

Comment thread torch/nn/utils/prune.py Outdated
Contributor

return module for same behavior as before in pruning

Contributor

This used to return a slightly more informative error when trying to remove pruning from an unpruned module ("Parameter '{}' of module {} has to be pruned before pruning can be removed" versus "No forward parametrizations found in tensor weight within module")

Collaborator Author

Changed the warnings thrown.

Comment thread torch/nn/utils/prune.py Outdated
Contributor

I would definitely split remove and undo into two separate functions given that they are two logically separate actions, as opposed to having: remove + leave pruned = remove reparametrization; remove + don't leave pruned = undo the pruning.

The one issue with undoing a reparametrization (and the reason why I chose not to implement an undo method for pruning) is that the definition of undoing is not clear for actions that can be called multiple times in an iterative fashion, like pruning. Will undo remove the whole history of pruning, or just the effect of the last pruning call? If the latter, then we would have to somehow keep track of the whole history of mask changes, which I believe should instead be left to the user for specific use cases that require it. Another ambiguity is that once you make pruning permanent, you wouldn't be able to use undo anymore.

Collaborator Author

Happy to separate this one into two functions; I think that can make it clearer from a user interface perspective.

I completely agree that, given the specific requirements of pruning, the undo method is tricky to implement on it. On the other hand, I think that the undo method can be quite useful in general. I can think of at least one general use case for this in the context of optimization on manifolds, where setting and undoing parametrizations has a very clear geometric meaning.

Comment thread test/test_nn.py Outdated
Contributor

This has a side effect that now the weight gets serialized in the state dict, whereas before it wasn't. This may cause compatibility issues for loading saved models.

Collaborator Author

True. How should we go about this? In my opinion, it makes more sense to have weight as a buffer, and having it serialized with the whole model, given that tensor hooks (I don't know about module hooks) are not serialized.

Contributor

The original (unpruned) weight and the mask are everything you need to recreate the pruned weight by simply multiplying the two together [true in pruning, not sure about other use cases]. So, in the pruning case, there would be no need for the weight to also be stored in the state_dict for saving -- it would just be redundant info. And this would be another copy of the whole net if all layers are at least partially pruned.
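As a tiny plain-Python illustration of the redundancy argument (the values are hypothetical): the pruned weight is fully recoverable from the original weight and the mask alone:

```python
orig_weight = [0.5, -1.2, 3.0, 0.7]
mask = [1, 0, 1, 0]

# The pruned weight is just the elementwise product, so storing it in
# the state_dict on top of `orig_weight` and `mask` adds no information.
pruned = [w * m for w, m in zip(orig_weight, mask)]
assert pruned == [0.5, 0.0, 3.0, 0.0]
```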

In general, though, wouldn't most users want to make the parametrization permanent (through .remove) before serializing?

Collaborator Author

I agree that it is somewhat redundant, but it is not entirely redundant when saving the module, since tensor hooks, at the moment, are not serialized. As such, if the user wants to save and later resume training, they will have to set the parametrization again. In this case, the saved buffer can be used as a check that the saving/loading has been performed correctly. My reasoning behind registering them as buffers was that this way it is more orthogonal to the way PyTorch works. For example, they are moved to the GPU when the user chooses to move the whole model to the GPU. Otherwise this would have to be done by hand.

I do not think that one would want to make parametrizations permanent before serializing in general. In the general case, a parametrization is a constraint on a certain parameter. For this reason, one may want to save a checkpoint of a model and resume training later. To be able to do so, one needs the parametrization and the original tensors to resume the optimization with the constraint.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

Doesn't a module already store its parametrizations as hooks without the need to pass them in separately as args?

Collaborator Author

Yep! Corrected now.

Comment thread torch/nn/utils/prune.py Outdated
Contributor

If users want to apply a parametrization that is logically compatible with pruning, they should be able to

Collaborator Author

@lezcano lezcano Feb 17, 2020

This comment is outdated; I will change it. We do not support several forward parametrizations at the same time. I thought that these could be supported, but by the very nature of how forward hooks at a module level work, there is no obvious way to chain them. We allow several pruning parametrizations because we automatically merge them into the same PruningContainer.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

This assumes a lot of interoperability between parametrizations. This is fine for pruning, because the pruning module handles how to exactly combine successive pruning calls, but this may not be true in the general reparametrization case.

Collaborator Author

It is left to the user to implement these in a compatible way. It does not seem wise to impose this artificially, given that there are plenty of use cases where this could be useful. For example, in the case of wanting to put constraints on a matrix, you may choose to have a constraint that makes the matrix positive definite, and then combine this with a constraint that makes the matrix positive definite with determinant one. This makes these very flexible and easy to mix and match.

Comment thread torch/nn/utils/prune.py Outdated
Contributor

needs an exception to raise :)

@albanD albanD self-requested a review February 18, 2020 18:25
@albanD albanD added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 18, 2020
Contributor

@apaszke apaszke left a comment

Requesting changes until the nn.Module situation is resolved.

@albanD
Collaborator

albanD commented Feb 28, 2020

After discussing with @apaszke , I think we want to do the following:

  • not modify the base nn.Module
  • set a custom attribute that replaces the original Parameter on the given module with an associated function that will handle recomputing the value and the caching
  • The current caching may feel very magical. It would be better to have it as a context manager to make it super clear to the user when we start caching and when we invalidate the cache.
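The context-manager idea in the last bullet can be sketched in plain Python (the `cached` name mirrors the proposal but the rest is illustrative only):

```python
from contextlib import contextmanager

class Parametrized:
    """Toy module whose `weight` is recomputed on every access,
    unless caching has been switched on."""

    def __init__(self, raw):
        self.raw = raw
        self.calls = 0          # counts recomputations
        self._cache = None
        self._caching = False

    @property
    def weight(self):
        if self._caching:
            if self._cache is None:
                self._cache = self._compute()
            return self._cache
        return self._compute()

    def _compute(self):
        self.calls += 1
        return 2 * self.raw

@contextmanager
def cached(module):
    # Caching starts here and the cache is dropped on exit, making it
    # explicit to the user when we start caching and when we invalidate.
    module._caching = True
    try:
        yield module
    finally:
        module._caching = False
        module._cache = None

m = Parametrized(3)
m.weight, m.weight
assert m.calls == 2            # no caching: recomputed on each access
with cached(m):
    m.weight, m.weight
assert m.calls == 3            # cached: computed once inside the block
```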

@lezcano
Collaborator Author

lezcano commented Mar 10, 2020

I finished implementing the caching system with a context manager, and registering the Parametrizations as nn.Modules. The final implementation is very clean and in the spirit of PyTorch.

With this implementation, putting orthogonal constraints on a model using the Cayley map would be as easy as:

class Orthogonal(nn.Parametrization):
    def forward(self, t):
        A = t.triu(1)
        A = A - A.t()
        Id = torch.eye(A.size(0))
        return torch.solve(Id - A, Id + A).solution

x = nn.Linear(4, 4, bias=False)
x.register_parametrization(Orthogonal(), "weight")

One may put several parametrizations on one parameter. For example, one could prune the resulting orthogonal tensor by adding the line:

x.register_parametrization(RandomUnstructured(0.5), "weight")

If the parameter is shared, as it is, for example, in an RNN, this can be handled by

with nn.cached(model) as model:
    out = model(inputs_)
loss = model.loss_crit(out)

With this design, the modules pruning, weight_norm and spectral_norm would be implemented as normal modules, registering some auxiliary buffers in __init__ (e.g. the mask buffer in pruning or the u buffer in spectral_norm) and implementing a forward function.

It remains to edit some functions in module to completely integrate these with nn.Module and, of course, to add some proper testing.

I have reimplemented the pruning module in terms of this one to test its correctness, and the main functionality can be implemented in about 80 lines. I have not added this to the commit as the necessary changes in the pruning module would be breaking.

Edit: We could also implement the functionality of using it as a decorator of the forward method, similar to that of torch.no_grad:

class Model(nn.Module):
    @nn.cached
    def forward(self, input_):
        pass

This way the user would not need to wrap every forward call in a context manager.
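A plain-Python sketch of that decorator variant (the thread proposes the name `nn.cached`; the stand-in `cached_forward` below and the toy class are illustrative only):

```python
from functools import wraps

def cached_forward(forward):
    """Toy decorator: enables caching for the duration of one forward
    call and invalidates the cache afterwards."""
    @wraps(forward)
    def wrapper(self, *args, **kwargs):
        self._caching = True
        try:
            return forward(self, *args, **kwargs)
        finally:
            self._caching = False
            self._cache = None
    return wrapper

class Toy:
    def __init__(self):
        self._caching = False
        self._cache = None
        self.calls = 0

    @property
    def weight(self):
        if self._caching and self._cache is not None:
            return self._cache
        self.calls += 1
        val = 42  # stands in for an expensive parametrization
        if self._caching:
            self._cache = val
        return val

    @cached_forward
    def forward(self):
        # `weight` is accessed twice but computed only once per call
        return self.weight + self.weight

t = Toy()
assert t.forward() == 84
assert t.calls == 1          # computed once within the decorated forward
assert t.weight == 42 and t.calls == 2   # no caching outside
```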

What are your thoughts on this design @apaszke ?

@apaszke
Contributor

apaszke commented Mar 15, 2020

@lezcano have you been in contact with @albanD? We've discussed potential solutions, and I had previously understood that the two of you were working on this together. This is not really in line with where our discussion ended up. I think that there is a hard requirement that parametrization should require no changes to nn.Module. All the code should live in torch.nn.utils.parametrization (possibly with a slightly different name) and should help people write parametrizations using a variant of @property that additionally supports caching.

Note that there should be no methods to monkey-patch a module! If you want to use nn.Linear as a base, but parametrize the weights differently, then I'd say you should do something along the lines of this:

class MyParametrizedLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @parametrized(cached=True)  # an optionally cached variant of @property
    def weight(self):
        return ...

@lezcano
Collaborator Author

lezcano commented Mar 16, 2020

The problem with that solution is that it does not solve the problem that this functionality was trying to address.

The idea of these parametrizations is that they should provide an easy way to build modules like BasePruningMethod, weight_norm, and spectral_norm as proposed in #7313. These modules require the following:

  • Having a caching system for the parametrization
  • They should be readily usable with pre-existing layers, as the methods mentioned above are
  • They should be able to register new buffers and parameters in the model
  • It would be desirable that parametrizations can be composed, e.g., to apply orthogonal constraints and then prune the resulting tensor

I think that I could modify the current implementation so that it is an external module. This module would add a property and a Parametrization to the module. The property would handle the caching, and the Parametrization would be an nn.Module, as it is now, that would handle the computation of the parametrization and would register any auxiliary buffers and parameters on itself. This implementation would already greatly simplify the code of the three modules mentioned before, while allowing these modules to be composed with each other.

Edit: Never mind, the previous solution does not work, as one cannot add a property dynamically to an instance of a class, only to the whole class. The other "solution" that occurred to me was monkey-patching __getattr__, but that cannot be done on a per-instance basis either.
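The limitation mentioned in the edit can be checked directly: a `property` attached to an instance is never triggered, since the descriptor protocol only consults the class (a minimal demo with hypothetical names):

```python
class Mod:
    pass

m = Mod()

# Attaching a property to an *instance* stores it as a plain attribute;
# the descriptor protocol is bypassed, so the property object itself
# comes back instead of 42.
m.weight = property(lambda self: 42)
assert isinstance(m.weight, property)

# Attaching it to the *class* works as expected.
Mod.weight2 = property(lambda self: 42)
assert m.weight2 == 42
```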

I am happy to consider different designs as, at the moment, I do not see how to implement this functionality without modifying nn.Module.

As the current implementation hints, this can be implemented with minimal changes to nn.Module, even fewer than those presented here. One only needs a way to detect which nn.Modules are parametrizations, and use this detection in __getattr__ in the case where the attribute is a module, returning module.cache in that case, where cache is an attribute of the parametrization which can be made into a property. This detection can be done via a thin wrapper around nn.Module which adds the cache buffer.

Everything else would work as usual: parametrizations would be treated as modules in every method, e.g., .modules() or .children(). This, together with the method to register a Parametrization on a Parameter or on another Parametrization (which takes care of composition in a natural way), and a function to remove the parametrization, both of which are already implemented, and maybe some auxiliary methods like .parametrizations() and a method that returns the active parametrization on a Parameter, would have everything working.

Then again, if this is off the table, I am happy to take suggestions on alternative designs, as all the designs proposed in the discussion in #7313 went along these lines.

@apaszke
Contributor

apaszke commented Mar 20, 2020

I'm not 100% sure we should really be modifying live module objects when you apply pruning. I don't think that this should be a strict requirement, and maybe we should return new objects instead. Still, even if you insist on having all the properties you've mentioned, it's still doable, as the snippet below shows:

import torch
import torch.nn as nn


class OffsetModule(nn.Module):
    def __init__(self, nfeatures):
        super().__init__()
        self.offset = nn.Parameter(torch.randn(nfeatures))

    def forward(self, input):
        return input + self.offset


def constant_offset(module):
    nfeatures = len(module.offset)
    # detach so the new parameter is a fresh leaf tensor
    module.singleton_offset = nn.Parameter(module.offset[0].detach())
    del module.offset

    def offset_generator(self):
        return self.singleton_offset.expand(nfeatures)

    param_cls = type('Constant' + module.__class__.__name__,
                     (module.__class__,),
                     {'offset': property(offset_generator)})
    module.__class__ = param_cls


x = torch.zeros(10)
model = OffsetModule(10)
print(model.state_dict())
print(model(x))

constant_offset(model)
print(model.state_dict())
print(model(x))
print(model.offset)

This doesn't include caching, but now that you can do arbitrary things in the attribute accessors, it shouldn't be very difficult to add. Let me know if you have any more concerns.

@lezcano
Collaborator Author

lezcano commented Mar 20, 2020

Yes, I think that a solution along these lines could work. The only potential problem I see is that, given that we are creating a new class, removing a parametrization once set could be slightly tricky. Even then, I think it could be done by creating an object of the parent class and using the state_dict of the current parametrization to populate its members.

I'll implement everything following this approach sometime in the next few days.

@lezcano lezcano changed the title Parametrization Functionality [WIP] Parametrization Functionality Mar 23, 2020
@lezcano
Collaborator Author

lezcano commented Mar 23, 2020

I just implemented everything outside of nn.Module. The current implementation ticks all the boxes described above. I have implemented the pruning functionality in terms of this module (not in this PR) and it works as expected. The same could be done with weight_norm and spectral_norm.

It also gets rid of the previous apply method. Now parametrizations are just normal modules. For example, a simple pruning module could be implemented as:

class PruningMethod(Parametrization):
    def __init__(self):
        super(PruningMethod, self).__init__()
        self.register_buffer("mask", None)

    def init(self, t):
        # t is the tensor on which the parametrization has been applied
        super(PruningMethod, self).init(t)
        self.mask = self.compute_mask(t)

    def compute_mask(self, t):
        raise NotImplementedError()

The current mechanism in the pruning classes, which merges two pruning classes into one by multiplying their masks when two pruning methods are applied to the same tensor, can also be implemented in a natural way with the current implementation of Parametrization by making the init method slightly more complex.

The current implementation allows parametrizations to be put on three kinds of objects: Parameters, buffers, and already-parametrized objects (the recursive step). It also implements a caching mechanism that can be activated through a context manager or a decorator on the forward method.

This implementation also allows registering new parameters in the parametrizations, if necessary. This can be useful when implementing more complex parametrizations.

In order to implement these three classes in terms of this new Parametrization class, we would have to deprecate the existing ones, as they work in an intrinsically different way. Also, it would now be reasonable to put this machinery under torch/nn/modules, as these objects will now be regular modules. I'd be happy to do all these changes in a different PR if approved.

If the current design is given the OK, I will implement some further testing and I will finish writing the documentation that is missing.

@lezcano
Collaborator Author

lezcano commented Mar 31, 2020

I have updated the PR, implementing the idea of chaining parametrizations. This is the equivalent of composing functions in a functional way. The idea was implicit in the previous implementation through the register_parametrization method, but now it can be used from the outside.

I am writing a library on top of this functionality to test it, and these methods have allowed me to build all the cases that I had in mind. Of course, this implementation still supports the implementation of the pruning methods.

I have also deleted the init() method, which was supposed to be used to initialize inner parameters of the parametrization that depend on the tensor on which it is registered, such as the mask buffer in a pruning method. Incidentally, this encouraged working with half-initialized objects. If one needs such objects in a parametrization, one should request them as parameters in the constructor.

For example, a basic pruning method would be implemented as:

class PruningMethod(Parametrization):
    def __init__(self):
        super(PruningMethod, self).__init__()
        self.register_buffer("mask", None)

    def forward(self, X):
        return self.mask * X

    def sample_mask_(self):
        if not self.is_registered():
            raise ValueError("...")
        self.mask = self.compute_mask(self.orig)

    def compute_mask(self, t):
        raise NotImplementedError()

or with a size parameter in the constructor, sampling a mask there. One may implement helper methods like the current l1_unstructured, random_unstructured, and so on, that pass the correct size to the method in the constructor, register it, and initialize it by calling sample_mask_. The current merging of masks could be implemented by overriding the Parametrization.chain method.
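The mask-merging optimization described above can be sketched in plain Python (a toy `chain`, not the PR's implementation; lists stand in for tensors):

```python
class Pruning:
    """Toy pruning parametrization: multiplies a tensor (here a list)
    by a fixed binary mask."""

    def __init__(self, mask):
        self.mask = mask

    def __call__(self, xs):
        return [m * x for m, x in zip(self.mask, xs)]

    def chain(self, other):
        # Two chained prunings collapse into a single one whose mask is
        # the elementwise product of the two masks.
        merged = [a * b for a, b in zip(self.mask, other.mask)]
        return Pruning(merged)

p1 = Pruning([1, 1, 0, 1])
p2 = Pruning([1, 0, 1, 1])
w = [2.0, 3.0, 4.0, 5.0]
merged = p1.chain(p2)
assert merged(w) == p2(p1(w)) == [2.0, 0.0, 0.0, 5.0]
```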

Contributor

@apaszke apaszke left a comment

This looks great! Still, I think there are things we should work on improving, and it might still be possible to make the API much nicer IMO. I put a bunch of inline comments, let me know what you think. Thanks for taking the time to work on this and sorry for my slow replies 😕

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

Why? It seems like a reasonable thing given that we allow registering the same parameter in multiple modules.

Collaborator Author

Many parametrizations are in general stateful (consider the mask in a pruning method, or the u and v tensors within spectral_norm), so it should not be allowed to register them on several tensors. This also allows for the current chain-like implementation (more on this in a comment below).

Contributor

Well but what if you really want to use the same parametrization from multiple modules? I don't see why should we forbid that. Parameters are stateful in their own ways too (e.g. .grad), but they still can be shared between modules.

Collaborator Author

I do not see any application for a shared parametrization. If a user really wanted to do so, they could create a new parametrization that shares the parameters with an existing one.

Having a parametrization only accepting a parameter allows writing parametrizations that can evaluate their current state calling self.evaluate(), and use this value to update their inner buffers, which is quite useful.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

I don't think we should register the cache as a buffer. I'd like to avoid having those show up in state_dict() because they're not really meaningfully part of the module's state.

Collaborator Author

Fair enough.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

weight_param sounds like what someone would name a "weight parameter". I think we should be a bit more verbose and call it weight_parametrization.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

I honestly don't love the mutable chaining syntax. My concerns are:

  • It's generally better to avoid mutable state unless we really need it. It would be more natural to me to have it return a new parametrization that you can later register.
  • The current syntax doesn't make it clear which parametrization gets evaluated first.
  • How does it work if my parametrization has multiple parameters? Imagine, as an example, a low-rank decomposition of the weight matrix as an outer product of two vectors.

Anyway, compare how much clearer this reads:

register_parametrization(module, "weight", p1 >> p2)

You can clearly see the order in which the parametrization pipeline is applied, and the operator can be made immutable. It doesn't solve the multiple-parameter case, though.

Collaborator Author

When you have multiple parameters, you can register them in a ModuleList and then implement the forward function as:

    def forward(self, Xs):
        return [p.evaluate(X) for p, X in zip(self.parametrizations, Xs)]

Funnily enough, I had already implemented a version of the lowrank example that you propose there, but with an SVD decomposition. In the case you propose, the parametrization should register an auxiliary parameter the size of the original vector, as the parametrization has more parameters than the original vector (it's a submersion).

The first and second points are very valid though. My reasoning for going with this design was the following: this chaining mechanism allows you, at any point, to evaluate the current value of the parametrization. I think of a parametrization as an RAII object over a buffer or a parameter, becoming the one that manages it. Consider a chain of two parametrizations, p1 from R^n to R^k and p2 from R^k to R^d, over a certain parameter t in R^n. In this setting, p2 does not have to know that it is acting on a parametrized object; it can always just ask for the value of the parameter, and it will be given p1(t), which lives in the correct space R^k, even though the original object doesn't.

This design also allows one to easily implement optimizations like the one currently being implemented in the pruning module, where, if two pruning modules are put one after the other, their masks are merged into one by multiplying them. This could be done by simply overloading the chain method.

This is related to the point you made below of flattening this chain of parametrizations into something similar to a Sequential module. The problem with that is that each parametrization does not have a way to evaluate the current state of the tensor (which might be necessary to initialize the state of the parametrization, like the mask) and does not know anything about the previous parametrizations either (no optimization like that of pruning is possible).

Contributor

When you have multiple parameters, you can register them in a ModuleList and then implement the forward function as:

Sorry, but I don't see how this applies to any of my points. Your example applies multiple independent parametrizations to multiple independent parameters, while I bring up issues about chaining parametrizations on a single parameter.

Funnily enough, I had already implemented a version of the lowrank example that you propose there, but with an SVD decomposition. In the case you propose, the parametrization should register an auxiliary parameter the size of the original vector, as the parametrization has more parameters than the original vector (it's a submersion).

What does that mean? Where does it have to register this "auxiliary parameter"? I think all of the state necessary to construct a parametrization should be held entirely in the Parametrization object. It should never leak its own details into other modules, but that means that you cannot assume that there's a single underlying tensor you'll be processing.

The first and second points are very valid though. My reasoning to go with this design was the following: ...

While I see what you're saying, I don't understand how it applies to my first two points. Chaining parametrizations doesn't require them to be stateful (just think of them as a composition of stateless functions!), and the >> syntax can be overloaded too. The only difference is that you'd be implementing __rshift__ instead of chain in a subclass!

This design also allows one to implement easily optimizations like the one currently being implemented in the pruning module, ...

I don't see how my proposed design would prevent this. You can inspect the parametrizations in __rshift__ and can return different specialized composed forms if that can be done.
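As a pure-Python illustration of this point (no torch; every name here is hypothetical), `__rshift__` can inspect both operands and return a specialized fused form when possible (two pruning masks multiply into one), falling back to an immutable Sequential-like wrapper otherwise:

```python
class Parametrization:
    def __rshift__(self, other):
        # fusion hook: two masks compose into a single mask
        if isinstance(self, Mask) and isinstance(other, Mask):
            return Mask([a * b for a, b in zip(self.mask, other.mask)])
        # fallback: immutable wrapper around the existing objects
        return Chain(self, other)


class Mask(Parametrization):
    def __init__(self, mask):
        self.mask = mask

    def forward(self, x):
        return [m * v for m, v in zip(self.mask, x)]


class Chain(Parametrization):
    """Immutable wrapper: applies left first, then right."""
    def __init__(self, left, right):
        self.left, self.right = left, right

    def forward(self, x):
        return self.right.forward(self.left.forward(x))


class Double(Parametrization):
    def forward(self, x):
        return [2 * v for v in x]


fused = Mask([1, 0, 1]) >> Mask([1, 1, 0])  # fuses into a single Mask
chained = Double() >> Mask([1, 0, 1])       # no fusion: falls back to Chain
print(fused.forward([5, 6, 7]))    # -> [5, 0, 0]
print(chained.forward([5, 6, 7]))  # -> [10, 0, 14]
```

Since `>>` never mutates its operands, the same parametrization objects can safely appear in several compositions.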

Contributor

@apaszke apaszke Apr 11, 2020

This is related to the point you made below of flattening this chain of parametrizations into something similar to a Sequential module. The problem with that is that each parametrization does not have a way to evaluate the current state of the tensor ... and does not know anything about the previous parametrizations either ....

Can you please elaborate on what you mean when you say that "a parametrization doesn't have a way to evaluate the current state of the tensor"? Also, it's actually good that it doesn't know anything about the next parametrizations, because that enables composability! The optimization you've mentioned should be applied whenever >> is applied. If the two parametrizations can be composed into a single one, then return the optimized version. If they can't, wrap them in the Sequential-like container. If you get a Sequential and another parametrization to append, then consider if you can't fuse the last two together and return a new Sequential with an optimized last layer.

Collaborator Author

@lezcano lezcano Apr 12, 2020

On how to implement the low-rank layer: one option would be to ask the low-rank parametrization to be applied to a tensor of size (2, n) and then implement

class LowRank(nn.Parametrization):
  def forward(self, t):
    a, b = t
    return a.unsqueeze(1) @ b.unsqueeze(0) # outer product

If the writer insists that the parametrization should be applied to a single tensor, that is when parameters inside parametrizations come into play. One would write:

class LowRankVector(nn.Parametrization):
    def __init__(self, n):
        super().__init__()
        self.b = nn.Parameter(torch.Tensor(1, n))

    def forward(self, a):
        return a.unsqueeze(1) @ self.b

Sometimes the latter option with parameters inside parametrizations is the only viable option if one wants to keep the same interface for different objects that implement different optimizations depending on the size of the input.

I do not think that we should support the possibility of applying parametrizations to two different parameters. If one wants to do so, they should couple them in just one parameter and decompose it inside forward. I have not found an example yet in which it is not reasonable to do so.

Given that parametrizations can be stateful, I do not think that returning a new parametrization is a very good option, as doing so would mean that you have to copy that state. In the case of pruning, for example, one would have to copy the mask, ending up with a duplicate mask as well. I do not think that we want to deep-copy a module every time we do >>. To resolve the possible ambiguity of the order in which chain works, we could rename it to compose. That way f.compose(g) would read, and behave, as f o g in mathematical notation.

Can you please elaborate on what do you mean when you say that "a parametrization doesn't have a way to evaluate the current state of the tensor

I completely agree that a parametrization should not know anything about the next parametrizations, but it should know about the previous ones. Now, consider the example of an adapter class that takes a parametrization from R^n to R^n in its constructor and uses it to implement a parametrization from R^k to R^k by composing a function from R^k to R^n on the left and a function from R^n to R^k on the right. This case appears when trying to do orthogonal optimisation on rectangular matrices (R^{nxd}) based on optimization on square orthogonal matrices (R^{nxn}). To do so, this adapter class has to prepend the parametrization going from R^k to R^n. Once it has done so, the parametrization from R^n to R^n is ready to be used on elements of R^k. Furthermore, if it has functions that use the current value of the parametrization, it can just invoke self.evaluate() and, as it is aware of all the previous parametrizations that were chained to it, everything just works. If, on the other hand, we implement everything as a flat Sequential-like method, a parametrization that's in the middle could not evaluate all the parametrizations up to itself.

Contributor

One of them would be to ask the low-rank parametrization to be applied on a tensor of size (2, n) and then implement

The assumption that you can pack all parameters of a parametrization into a single tensor is too strong. For starters, it's just not the right way to model this semantically. Secondly, what if you wanted to have a parametrization that would e.g. restrict a weight to be an affine transformation of another weight? You'd need to keep both the matrix and the bias, and doing that in a single tensor is super annoying.

I do not think that we should support the possibility of applying parametrizations to two different parameters.

I don't see why we should disallow that, especially since both parameters and modules support that. It would seem very inconsistent to me if you couldn't do that.

Given that parametrizations can be stateful, I do not think that returning a new parametrization is a very good option, as doing so would mean that you have to copy that parameter.

I don't really follow your reasoning. The parametrization returned by >> can be a thin wrapper around existing parametrization objects, so there would be no need to reallocate them.

To resolve the possible ambiguity of the order in which chain works we could rename it to compose.

I don't think this resolves the issue. Both "chain" and "compose" have the same meaning and don't really clarify the direction in which the parametrizations are composed. Modelling it after a mathematical notation doesn't really help as I personally often struggle to recall which function comes first in f o g.

Furthermore, if it has functions that use the current value of the parametrization, it can just invoke self.evaluate() and, as it is aware of all the previous parametrizations that were chained to it, everything just works.

My question is why would it have functions that do that?

If, on the other hand, we implement everything as a flat Sequential-like method, a parametrization that's in the middle could not evaluate all the parametrizations up to itself.

I think that's a good thing, because it ensures compositionality!


Ok, I think there are two ways to look at composition of parametrizations. One is that it's simply function composition, and they don't have to be aware of each other. This naturally leads to a Sequential-like construction to represent their composition. The other way is to consider them more like a stack. But then, I'd argue that the right way to do it is to apply the parametrizations to parameters of existing parametrizations. This way you get the nice property that you can still call .evaluate() and get the desired result (the "previous" parametrizations will be automatically invoked when their respective attributes are accessed).

Does that make sense? I still don't fully understand why you might want to evaluate a parametrization and have it be aware of the things that come before it so it would be helpful to know. Unfortunately I wasn't able to follow your example. Also, if we go down the second path, then I'm pretty sure that you shouldn't be able to pass in custom tensors to .evaluate(). A parametrization should then simply have forward(self) (i.e. no arguments) that produces the result whenever called (and possibly caches the result).

Collaborator Author

@lezcano lezcano Apr 18, 2020

Regarding your point of putting a parametrization that depends on several parameters, I have just implemented it so that one may write:
P.register_parametrization(module, ["weight", "bias"], MyParametrization(), name="combined")
where MyParametrization() is a parametrization whose forward takes two arguments. This way we allow parametrizations that take several inputs and return several outputs. The original tensors are not moved into the parametrization; we just keep a list of references to those tensors. Each tensor will still be accessible under _tensor_name, that is, we prepend an underscore to its name.
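A pure-Python sketch (no torch, and not this PR's actual code) of what such a multi-tensor registration could look like under the hood: the originals are stashed under underscore-prefixed names, and the computed value is exposed as a property on a dynamically created subclass. All helper names below are illustrative assumptions.

```python
class Module:
    pass


def register_parametrization(module, tensor_names, parametrization, name):
    """Hypothetical helper: keep references to the original tensors and
    expose the computed value as a property named `name`."""
    prefixed = []
    for tensor_name in tensor_names:
        value = getattr(module, tensor_name)
        delattr(module, tensor_name)
        setattr(module, "_" + tensor_name, value)  # original stays reachable
        prefixed.append("_" + tensor_name)

    def getter(self):
        return parametrization(*[getattr(self, n) for n in prefixed])

    # properties live on classes, so give the module its own subclass
    # (this mirrors the dynamically created ParametrizedLinear class that
    # causes the pickling trouble discussed further down in this thread)
    cls = type("Parametrized" + type(module).__name__, (type(module),), {})
    setattr(cls, name, property(getter))
    module.__class__ = cls


m = Module()
m.weight = [1.0, 2.0]
m.bias = [10.0, 20.0]

# a parametrization whose forward takes two inputs (here, a plain callable)
register_parametrization(m, ["weight", "bias"],
                         lambda w, b: [x + y for x, y in zip(w, b)],
                         name="combined")
print(m.combined)  # -> [11.0, 22.0]
print(m._weight)   # originals remain accessible under the prefixed names
```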

I completely agree on the two ways of looking at parametrizations. Let me describe a more concrete example and see if it makes sense. Assume that you have a function that goes from R^{n x n} to SO(n), where SO(n) is the space of orthogonal matrices. In PyTorch, this function reads:

def f(X):
  X = X.tril(-1)
  X = X - X.t()      # X is now skew-symmetric
  return expm(X) # The matrix exponential of a skew-symmetric matrix is orthogonal

Now, assume that we want to parametrize a non-square matrix of shape n x k to have orthogonal columns. It turns out that, if we append a matrix of size n x (n-k) of zeros on the left and pass it to f we get an n x n matrix. If we then take the k first columns of this orthogonal matrix, we get a matrix whose columns are orthogonal. This construction makes quite a bit of mathematical sense, but the why is not important. The idea is that there is an embedding that has to be applied before the f and a projection that has to be applied afterwards. This idea may be abstracted into a class that roughly reads:

class Fibration(P.Parametrization):
    def __init__(self, parametrization):
        super().__init__()
        self.parametrization = parametrization

    def embedding(self, X):
        raise NotImplementedError()

    def projection(self, X):
        raise NotImplementedError()

    def forward(self, X):
        X = self.embedding(X)
        X = self.parametrization.evaluate(X)
        X = self.projection(X)
        return X

Needless to say, this represents a mathematical object of its own. Now, if we have a function inside the parametrization that needs to make use of the evaluation of the current state of the parametrization chain, this implementation of Fibration will not work, as self.parametrization has not been registered anywhere.

On why a parametrization needs access to the parameters, there are two examples. One is the initialization of the parameters. In the RAII idea of a parametrization, parametrizations would take care of the initialization of their parameters, as they know how they are using them. For example, in f we are using tril to parametrize the skew-symmetric matrix, but we could also be using triu. As such, it is reasonable that the Parametrization that has f as its forward is the one that exports methods to initialize the parameters that it manages. Here I have not explained why a parametrization would also need to evaluate the parametrization chain it is on, but I think that it follows reasonably from all this. I can continue the example to explain that if you feel it would help.

This RAII idea sort of enforces that a parametrization can only be registered in one parameter / tensor at a time. This idea is what the current implementation follows. If one wanted to effectively have the same parametrization on two parameters, one could create two different parametrizations and make them share their parameters.

Contributor

Great, thanks for the more detailed write-up. To answer the first part (i.e. your example for why we even need to be able to evaluate parametrizations on something other than their parameters), I think that we're mixing up two APIs that we really shouldn't. If you never need parameters, then you never should have made this thing a module in the first place. Ideally, you'd keep two APIs for parametrizations: one functional (where all parameters have to be passed in as arguments), and one nn.Module based, which takes care of state management for you, but generally defers to the stateless functional implementation in forward. That's the pattern we already use for nn and we should continue it here.

On why does a parametrization need access to the parameters

I'm sorry, I can't follow your argument, because I'm not sure which part of my argumentation are you referring to. It would be very helpful if you could cite the relevant parts inline. Sorry 😕

By the way do you think it might be helpful to have a higher bandwidth meeting some time (like a voice call)? I feel like we've been going back and forth a lot, and writing up those arguments and rereading all the previous messages takes a lot of time. Can you please send me an email describing your availability so that we can arrange for a time? I'm hoping that would allow us to find a good design much more quickly!

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

Something to check before this gets merged: can parametrized models still be saved with torch.save? Pickle doesn't like weird things like this too much, so we should verify this.

Contributor

I'm pretty sure it won't work out of the box, but it should be possible to fix it using some custom reducer functions.

Collaborator Author

I have never dealt with these corner cases of the pickling module, so I don't know how to proceed here. Happy to take suggestions.

Contributor

Can you please try pickling modules that have parametrizations applied and see if that fails or not? We'll think about solutions once we confirm that it is an issue.

Collaborator Author

Fails with:

_pickle.PicklingError: Can't pickle <class 'torch.nn.utils.parametrize.ParametrizedLinear'>: attribute lookup ParametrizedLinear on torch.nn.utils.parametrize failed

Collaborator Author

@lezcano lezcano Apr 16, 2020

I gave this a go in the last commit. Now parametrized layers work well with torch.load / torch.save. Take a look at it and see if you like it. I also added some testing for this.

This approach still throws a warning when being loaded:

UserWarning: Couldn't retrieve source code for container of type ParametrizedLinear. It won't be checked for correctness upon loading.
  warnings.warn("Couldn't retrieve source code for container of " 

I didn't manage to get rid of that one.

It is somewhat hacky, but one literally needs to hack around the limitations of pickle. If you know of a better way, I can change it.
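For reference, the standard shape of such a hack is a custom `__reduce__` on the dynamically created class, so that pickle stores a recipe (recreate the class, then restore the state) instead of looking the class up by name. This is a simplified pure-Python sketch, not the PR's actual code; note that the rebuilt class here does not carry the custom `__reduce__`, so only one save/load round trip works.

```python
import pickle


class Linear:
    def __init__(self, weight):
        self.weight = weight


def _parametrized_class(base):
    # dynamically created class: pickle cannot find it by name lookup
    return type("Parametrized" + base.__name__, (base,), {})


def _rebuild(base, state):
    # recipe stored in the pickle: recreate the class, restore the state
    obj = object.__new__(_parametrized_class(base))
    obj.__dict__.update(state)
    return obj


def parametrize(module):
    cls = _parametrized_class(type(module))

    def __reduce__(self):
        return (_rebuild, (type(self).__bases__[0], self.__dict__))

    cls.__reduce__ = __reduce__
    module.__class__ = cls
    return module


m = parametrize(Linear(weight=[1.0, 2.0]))
m2 = pickle.loads(pickle.dumps(m))
print(type(m2).__name__, m2.weight)  # -> ParametrizedLinear [1.0, 2.0]
```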

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

nit: orig is not a lot shorter than original and the latter reads much better. Please don't shorten the names too much.

Collaborator Author

I took the wording from the current implementation of the pruning module, but I agree. I'll change it.

Collaborator Author

@lezcano lezcano Apr 9, 2020

Never mind: I kept the name orig because you also have the original() method that returns the original parameter from a parametrization chain. Should I change orig to _original?

Contributor

Yeah let's do _original then.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

This seems quite complicated, and I think it could be improved once you make Parametrization objects less stateful. You should be able to construct multiple of those and connect them into a container akin to nn.Sequential (maybe using this >> operator). Then you get this:

class ParametrizationChain(nn.Module):
  def __init__(self, *parametrizations):
    super().__init__()
    self.parametrizations = nn.ModuleList(parametrizations)

  def register(self, tensor):
    self.parametrizations[0].register(tensor)

  def evaluate(self, tensor=None):
    for p in self.parametrizations:
      tensor = p.evaluate(tensor)
    return tensor

Note that I've isolated the register function to make sure you don't have to muck around too much with the orig attribute in register_parametrization. It should defer setting up the state of the parametrization to the actual object.

Collaborator Author

Answered above.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

I'm not so sure about the "append a prefix to the attribute" convention. Maybe it would be nicer to add something à la ModuleDict called .parametrizations, and then you'd be able to do linear.parametrizations.weight?

Collaborator Author

I took this convention from the current implementation of the pruning modules and spectral_norm and weight_norm. The problem with adding a new module dict is that one has to implement all the related stuff to iterate over it, and make it appear in __repr__ and so on, which is quite painful if this is implemented outside of Module. Maybe changing it to tensor_name + "_parametrization"?

Contributor

Sorry, can you elaborate on the module dict issues? What's the problem with iterating over it? Also, parameters usually don't appear in __repr__, so what's the deal there?

Collaborator Author

Parametrizations are modules, and might have parameters and buffers of their own. As such, they should appear in __repr__, and they should be returned by methods like children, modules, so that parameters works correctly.

Contributor

I don't think that they necessarily have to appear in the __repr__, so I wouldn't bother with that. About appearing in all the submodule-related methods: yes, they should appear there, but you get that for free if you use ModuleDict, so what's the problem? It's a lot simpler than doing string arithmetic.

Collaborator Author

@lezcano lezcano Apr 16, 2020

Oh, I didn't know that those were iterated over as submodules, but it does make a lot of sense. I just implemented it and it works like a charm. It is also much cleaner.

I was thinking of doing the same with an OrderedDict for the caches. What do you think about that? Doing this we would completely get rid of all the nasty string arithmetic in the code, which I don't particularly like either.
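A hedged sketch of the ModuleDict convention agreed on here (`Scale` is an invented parametrization, and nothing below is this PR's actual wiring): parametrizations hang off `module.parametrizations.<tensor_name>`, so they are visited for free by `.modules()`, `.parameters()`, `.to()`, and friends.

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    """Illustrative parametrization with a parameter of its own."""
    def __init__(self):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(()))

    def forward(self, X):
        return self.log_scale.exp() * X


linear = nn.Linear(3, 3)
# register the parametrization under a ModuleDict instead of a
# string-mangled attribute like `weight_parametrization`
linear.parametrizations = nn.ModuleDict({"weight": Scale()})

# the parametrization's own parameter is picked up automatically,
# with no string arithmetic on attribute names
names = [name for name, _ in linear.named_parameters()]
print(names)  # includes 'parametrizations.weight.log_scale'
```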

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

I don't follow this caching logic. Why put it on the original module and not the Parametrization that has produced this? Also, if you put the caching logic into the Parametrization then the getter can always use the same code. Finally, it might be worth considering supporting setattr so that people can easily remove parametrizations by assigning a Parameter to a parametrized one, or overwriting the current one.

Collaborator Author

@lezcano lezcano Apr 9, 2020

The getter can always be the same code, I just separated it to simplify it, but I could just leave it as the same code. The caching logic is in the original module because there should be just one cache per parametrized element. Otherwise, if you put 8 parametrizations on a buffer you would end up with 8 different cached tensors.

The setattr... I am not sure how I feel about it. Is there any use-case where one would want to rewrite a parameter or a tensor with a completely new tensor, rather than removing first the parametrization with P.remove()? It would also require monkey-patching setattr, but I guess that that'd be alright.

Contributor

Oh, good point about not saving the intermediate parametrizations! But you could still work around this: when the Sequential-like module for parametrizations is called, disable the caching for the duration of forward, to prevent the inner parametrizations from saving anything. However, once forward is over, the caching is restored, and it's the Sequential that will never get its forward called again, because the cache lives in this object.

Collaborator Author

See long answer above on using a Sequential-like module.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

.apply will generally be quite expensive and it's always best to avoid it in the hot path. I'm wondering if there's any reason why someone would want to enable caching on a single module and not on a whole block of code (and then we could use a thread-local property that's cheap to enable). Thoughts?

Collaborator Author

I do not know of any use-case where someone would want to do that. I'd be keen on going with the latter option. Could you elaborate on how to implement it?

Contributor

You could make it a global or a thread-local flag that's stored in this module.

Collaborator Author

@lezcano lezcano Apr 12, 2020

I thought a bit more about this one, and I do not think it is possible to do it like this. This apply basically signals the parametrizations when to recompute the cache, and it has to be done for every parametrized module in the network.
This function is called just twice per minibatch, so I do not think it's that problematic. If one wanted to optimise it, we also offer the more local decorator:

class ParametrizedRNNCell(nn.Module):
  ...

class RNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.rnn = ParametrizedRNNCell()

  @parametrized_method
  def forward(self, xs):
    out = None
    for x in xs:
      out = self.rnn(x, out)
    return out

which may be activated in the module that uses a parametrized object several times. This performs an apply just on the submodules in the RNN, rather than the whole network.

This could also be done by exposing two free functions like activate_caching and invalidate_caching that take a module and a tensor name that could be executed manually for every parametrized layer at the beginning and end of each iteration in the training loop, in order to avoid the apply call.

Contributor

I thought a bit more about this one, and I do not think it is possible to do it like this.

Can you elaborate on why? I don't see any issues.

Collaborator Author

@lezcano lezcano Apr 16, 2020

Consider the example:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(...)
        self.b = nn.Parameter(...)
        P.register_parametrization(self, "a", MyParam())
        P.register_parametrization(self, "b", MyParam())

    def forward(self, input_):
        # 1.a use a
        # 2.a use a
        # 1.b use b
        # 3.a use a
        # 2.b use b

net = Net()

for batch in batches:
    with cached(net) as net:
        out = net(batch)

How would you implement cached so that in two different batches the cache is only updated in 1.a and 1.b, but not in any of the other calls in multiple iterations of the loop? I truly do not see how to do it without apply_...

Contributor

How about this:

from contextlib import contextmanager

_cache_enabled = 0
_cache = {}

@contextmanager
def cache():
  global _cache_enabled, _cache
  _cache_enabled += 1
  try:
    yield
  finally:
    _cache_enabled -= 1
    if not _cache_enabled:
      _cache = {}


class Parametrization(nn.Module):
  def forward(self, X):
    if not _cache_enabled:
      return self.evaluate(X)
    if id(self) not in _cache:
      _cache[id(self)] = self.evaluate(X)
    return _cache[id(self)]

  def __del__(self):
    # avoid a KeyError if this object was never cached
    _cache.pop(id(self), None)
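A runnable pure-Python distillation of this global-cache design (stripped of torch; `Parametrization` below is a plain class standing in for the nn.Module version, with a call counter added to make the memoization visible): inside `cache()` each parametrization is evaluated once and served from the cache afterwards, while outside it every access recomputes.

```python
from contextlib import contextmanager

_cache_enabled = 0
_cache = {}


@contextmanager
def cache():
    global _cache_enabled, _cache
    _cache_enabled += 1
    try:
        yield
    finally:
        _cache_enabled -= 1
        if not _cache_enabled:
            _cache = {}


class Parametrization:
    def __init__(self):
        self.calls = 0  # counts real evaluations, to show the memoization

    def evaluate(self, x):
        self.calls += 1
        return 2 * x

    def forward(self, x):
        if not _cache_enabled:
            return self.evaluate(x)
        if id(self) not in _cache:
            _cache[id(self)] = self.evaluate(x)
        return _cache[id(self)]


p = Parametrization()
with cache():
    p.forward(3); p.forward(3); p.forward(3)
print(p.calls)  # -> 1: evaluated once, served from the cache afterwards
p.forward(3)
print(p.calls)  # -> 2: no active cache, so it recomputes
```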

Collaborator Author

lezcano commented Apr 9, 2020

I just updated the PR with the changes of going from _param to _parametrization and not having the cache as a buffer. I updated the tests accordingly.

Contributor

@apaszke apaszke left a comment

This is definitely going in a good direction now! It looks great, but I have two last concerns:

  • I thought we decided to get rid of last_parametrization during our chat. Chaining parametrizations should apply them to parametrization parameters instead (so that you get the recursive behavior while forward is computed).
  • I don't see the need to explicitly opt-in parametrizations into caching. Why can't we say that either all of them are cached or none of them are? If there's really a use case for that kind of thing, why can't we make it an opt-out thing?

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

Why don't we just drop the keys?

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

I'm a bit confused. I thought we had decided that, in the end, we shouldn't make parametrizations aware of the chain they are in. Are you still planning to remove this?

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

Can't we make accessing .original throw an error like that? Then we wouldn't need evaluate at all

Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

nit: this is unnecessary

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

You're using id(module) here and module below! (By the way please add a test case for that if your existing tests haven't caught this error).

Collaborator Author

Yeah, I did not update the tests because things were changing on a weekly basis. I will write proper, thorough tests shortly, now that we have agreed on the final design.

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

nit: unnecessary

Comment thread torch/nn/utils/parametrize.py Outdated
Contributor

Oh, that is interesting, why are you doing it like that? I thought that all parametrizations should be cached by default when we enter a cached() block

Comment thread test/test_nn.py Outdated
Contributor

A question: can we make this reparametrization explicit, e.g.:

model.weight = Skew(Orthogonal(model.weight))

?

Collaborator Author

At the moment this is not possible, as we need to modify the Module instance to be able to register the parametrizations, although I agree that it would be awesome to have an API like this.

Collaborator Author

I guess it could be done by modifying the __setattr__ in nn.Module, but at least for the first iteration @apaszke said that modifying nn.Module was off the table.

Contributor

a middle ground could be model.register_parametrization("weight", nn.Sequential(Orthogonal(), Skew())). I think explicit chaining is better than implicit chaining - and more amenable to programmatic modifications / pretty printing etc
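For intuition, here is a minimal pure-Python sketch of what explicit chaining could look like. The `Module` class and `register_parametrization` method here are hypothetical stand-ins, not the actual PyTorch API:

```python
class Module:
    """Hypothetical stand-in for nn.Module; not the real PyTorch API."""
    def __init__(self):
        self.original = 2.0   # the unconstrained parameter
        self._chain = []      # parametrizations, applied in registration order

    @property
    def weight(self):
        # Reading .weight applies the whole chain to the original value.
        value = self.original
        for f in self._chain:
            value = f(value)
        return value

    def register_parametrization(self, f):
        self._chain.append(f)

m = Module()
m.register_parametrization(lambda x: x + 1)   # applied first
m.register_parametrization(lambda x: x * 10)  # applied second
print(m.weight)  # -> 30.0, i.e. (2.0 + 1) * 10
```

The explicit list makes the chain inspectable and printable, which is the "amenable to programmatic modifications" point above.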

Collaborator Author

@lezcano lezcano May 14, 2020

Adam and I are discussing exactly that point. It is true that explicit is better than implicit, but it is not very clear how to make that work while still giving the parametrizations control over the parameter itself.

In particular, we would want parametrizations to be able to evaluate the current chain of parametrizations up to them (this is useful to implement parametrizations that change according to certain conditions) and to be able to modify the parameter (for initialization purposes). I think of a Parametrization as a bit of an RAII object with respect to the parameter it is applied to.

The Sequential approach does not give the items it contains enough information to implement these things.

Contributor

@vadimkantorov vadimkantorov May 14, 2020

Maybe another special container for parametrizations could be implemented, e.g. nn.ParametrizationList(...)?

The initialization concern cannot be solved in the general case :(, since manual initialization is tied to knowing the parametrization used.

If torch.Tensor / nn.Parameter supported evaluation through some virtual method (e.g. forward() / __call__), then reparametrization buffers and reinitialization could be attached to an existing Tensor / Parameter, but this would require modifying basic APIs / structures (analogous to modifying Module's __setattr__ / __getattr__). Maybe this is a bit similar to quantized tensor representations, where some custom calculation is needed to recover a representation that can be operated on using conventional existing methods.

Also, one useful design question: will it support the functional interface well? (where no modules are used at all)

Collaborator Author

@lezcano lezcano May 14, 2020

I agree that attaching this to Tensor and Parameter would be a great idea, as was discussed in #7313, but it would be quite a big change. Maybe it is more reasonable to wait and see how this way of performing constrained optimisation is adopted, and, if people like it, revisit it then?

At the moment, parametrizations depend on Modules in a fundamental way, as they basically inject a property on the module in place of the weight, so I don't think that they would work well with the functional interface...
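To make the "inject a property" mechanism concrete, here is a minimal pure-Python sketch (no torch; `Module`, `original`, and `register_parametrization` are illustrative names, and PyTorch's real implementation differs in the details):

```python
class Module:
    """Hypothetical stand-in for nn.Module holding one parameter."""
    def __init__(self):
        self.original = 3.0  # the unconstrained parameter

def register_parametrization(module, name, parametrization):
    # Swap in a per-instance subclass carrying a property, so that
    # reading module.<name> runs the parametrization while other
    # instances of the original class remain unaffected.
    cls = type(module)
    parametrized = type("Parametrized" + cls.__name__, (cls,), {
        name: property(lambda self: parametrization(self.original)),
    })
    module.__class__ = parametrized

m = Module()
register_parametrization(m, "weight", lambda x: x * x)
print(m.weight)   # -> 9.0
m.original = 4.0
print(m.weight)   # -> 16.0, recomputed from the updated original
```

Because the attribute lives on the module's class rather than on the tensor, a purely functional interface (no modules) has nothing to hang the property on, which is the limitation described above.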

Contributor

At the moment, parametrizations depend on Modules in a fundamental way, as they basically inject a property on the module in place of the weight, so I don't think that they would work well with the functional interface...

sigh :(

@emcastillo
Collaborator

@lezcano hi! I have been talking with @albanD about this PR and since it is mostly done, I can take over the bits that are left to completion if you are ok with it 😄

@lezcano
Collaborator Author

lezcano commented Nov 19, 2020

I sent you a message on Slack :)

@t-vi
Collaborator

t-vi commented Nov 19, 2020

Not saying we can't do it, but I think staying closer to the Tensor and Module interface (similar to the Funsor/Tensule blog) will give a much better UX, even if the technical bits are riskier (in particular re JITing, but the current things will require work there, too).

@vadimkantorov
Contributor

Another useful API could be "freeze" (and maybe "unfreeze"?), so that the cached versions are not recalculated again. This is a bit akin to BatchNorm "freezing", but there the updates are related only to the forward pass. In general, this sort of abstract interface could be a useful one for Module; it may be good to discuss.

@lezcano
Collaborator Author

lezcano commented Nov 19, 2020

I completely agree with @t-vi on that, but @apaszke was very much against this...

About "freeze", is that not what cached and not_cached do in the current implementation? Or were you thinking of some extra functionality on top of that?

@vadimkantorov
Contributor

About freeze, yep, implementation-wise I think your caching may be doing exactly that.

I proposed promoting it to a more generic Module API, taking into account that this is also needed for BatchNorm.

The use case may be calling it before saving the model (like what remove_weight_norm does for the core weight_norm), so that it can be loaded for inference by a model that is not reparametrized.

@vadimkantorov
Contributor

2. So it will always be differentiable.

As long as the user controls it, it's cool :)

  1. But if you have simple quantization that just changes the way the weights are stored

Yes, I think it'd be a nice illustrative example, even if not practical or usable irl

Regarding (3), I propose giving a second thought to naming. If we can avoid naming something with a generic technical term like initialize when there is a proper mathematical concept behind it, that would be nice!

@lezcano
Collaborator Author

lezcano commented Feb 2, 2021

@vadimkantorov
Naming of initialize: As I said before, I am not set on the naming of the initialize_ method. I named it like that for the reasons above, as something like right_inverse would not be fair when one deals with modules. I am very open to suggestions on this one.

@albanD
Thoughts on the extension: There is a subtle detail regarding the proposed generalisation with several input vectors. In order to be able to put a parametrisation on a tensor, that tensor needs to be in the image of the map that you are using as the parametrisation. This is not problematic in your UpdatedWeightNorm example, as the map is surjective (any tensor is in the image), but it is a bit more problematic in the case of the pruning methods. In that case, the image is just the set of vectors that have zeros in the positions where the mask has zeros. This could be solved by doing self.X = parametrisation(X) before registering the parametrisation, but this is clearly annoying. It would be even more annoying in the case of, say, a rank-1 matrix computed as $u^Tv$. In that case, you would have to sample a rank-1 matrix somehow. This takes a few lines of code, which means the parametrisations would almost always need a sample() method that samples tensors lying in the image of the map.

In more elaborate cases such as the rank-1 case, this is not so problematic, as you may always have a reasonable sampler for that space. I had to implement this sample() idea for all the spaces in my GeoTorch library: any parametrisation P in that library has a P.sample() method. What I mean is that the problem of choosing a reasonable sampler in the image of the map is always there, regardless of the implementation. As such, the question to answer would be:

  • What is a good design that allows having multiple input tensors, but also makes applying pruning methods reasonably simple?

Edit.
On t.detach(): this original.data = t came from the pruning methods, where one wants to keep the same object after taking out the parametrisation (same id and so on), so that one does not need to update the parameters of the optimiser after removing the parametrisation. I believe that the t.detach() solution does not achieve this. What would then be the correct way of achieving this?
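The "tensor must be in the image of the map" point can be illustrated with a toy, pure-Python parametrization (`NonNegative` is a made-up example, using floats instead of tensors):

```python
class NonNegative:
    """Toy parametrization whose image is [0, +inf)."""
    def forward(self, x):
        return x * x

    def right_inverse(self, y):
        # A preimage exists only for values inside the image.
        if y < 0:
            raise ValueError("value is not in the image of forward")
        return y ** 0.5

p = NonNegative()
X = 9.0
# Registering on X is fine: forward(right_inverse(X)) recovers X.
assert p.forward(p.right_inverse(X)) == X
try:
    p.right_inverse(-1.0)   # -1.0 is outside the image: must be rejected
except ValueError as err:
    print("rejected:", err)
```

The pruning case is analogous: the image is the set of tensors that are zero wherever the mask is zero, so an arbitrary weight tensor has no preimage until it is first projected onto that set.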

@lezcano lezcano changed the title [WIP] Parametrization Functionality Parametrization Functionality Feb 3, 2021
Comment thread docs/source/nn.rst Outdated
Contributor

Asymmetric naming... maybe add_caching / remove_caching, or enable_caching / disable_caching?

Collaborator Author

With respect to what is this naming not symmetric?

Also, if I had to choose between those two, I would go with add / remove, as the other one might suggest it is a low-level API for the cached context manager, which is not the case.

@vadimkantorov
Contributor

if remove_parametrization by name is a thing, maybe also worth adding the basic one for weight/param: #46886

Collaborator

@albanD albanD left a comment


That looks quite good to me.
The only updates needed are around the docs, the way you detect enabled caching, and a stricter check on removal.

One thing that is missing from our next step plan in the issue is to write a Note/tutorial on how to write parametrization modules.
I think this should be done at the same time as we add our "example" parametrization.

Comment thread docs/source/nn.rst Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread torch/nn/utils/parametrize.py Outdated
Comment thread test/test_nn.py Outdated
Comment thread test/test_nn.py Outdated
Comment thread test/test_nn.py Outdated
@lezcano
Copy link
Copy Markdown
Collaborator Author

lezcano commented Mar 1, 2021

@albanD I have corrected all of the things that you mentioned. I left a few comments in the ones that I was not sure about. Please tell me what you think about them.

@lezcano lezcano requested a review from jbschlosser as a code owner March 1, 2021 15:52
Collaborator

@albanD albanD left a comment


This just needs a rebase and fixing flake/typing.

Comment thread torch/nn/utils/parametrize.py Outdated
Collaborator

nit: if caching is disabled
also typo

Collaborator

@albanD albanD left a comment


All looks good.
Can you rebase on top of viable/strict to make sure all CI is green?

Contributor

@facebook-github-bot facebook-github-bot left a comment


@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@albanD merged this pull request in 7aeee28.

aocsa pushed a commit to Quansight/pytorch that referenced this pull request Mar 15, 2021
Summary:
Provides the implementation for feature request issue pytorch#28937.

Adds the `Parametrization` functionality and implements `Pruning` on top of it.
It adds the `auto` mode, on which the parametrization is just computed once per forwards pass. The previous implementation computed the pruning on every forward, which is not optimal when pruning RNNs for example.

It implements a caching mechanism for parameters. This is implemented through the mechanism proposed at the end of the discussion pytorch#7313. In particular, it assumes that the user will not manually change the updated parameters between the call to `backwards()` and the `optimizer.step()`. If they do so, they would need to manually call the `.invalidate()` function provided in the implementation. This could be made into a function that gets a model and invalidates all the parameters in it. It might be the case that this function has to be called in the `.cuda()` and `.to` and related functions.

As described in pytorch#7313, this could be used, to implement in a cleaner way the `weight_norm` and `spectral_norm` functions. It also allows, as described in pytorch#28937, for the implementation of constrained optimization on manifolds (i.e. orthogonal constraints, positive definite matrices, invertible matrices, weights on the sphere or the hyperbolic space...)

TODO (when implementation is validated):
- More thorough test
- Documentation

Resolves  pytorch#28937

albanD

Pull Request resolved: pytorch#33344

Reviewed By: zhangguanheng66

Differential Revision: D26816708

Pulled By: albanD

fbshipit-source-id: 07c8f0da661f74e919767eae31335a9c60d9e8fe
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
facebook-github-bot pushed a commit that referenced this pull request Jun 14, 2021
Summary:
Makes it possible for the first registered parametrization to depend on a number of parameters rather than just one. Examples of these types of parametrizations are `torch.nn.utils.weight_norm` and low-rank parametrizations via the multiplication of an `n x k` tensor by a `k x m` tensor with `k <= m, n`.

Follows the plan outlined in #33344 (comment). A short summary of the idea is: we call `right_inverse` when registering a parametrization to generate the tensors that we are going to save. If `right_inverse` returns a sequence of tensors, then we save them as `original0`, `original1`...  If it returns a `Tensor` or a sequence of length 1, we save it as `original`.

We only allow many-to-one parametrizations as the first parametrization registered; subsequent parametrizations need to be one-to-one.

There were a number of choices in the implementation:

If the `right_inverse` returns a sequence of parameters, then we unpack it in the forward. This is to allow to write code as:
```python
class Sum(nn.Module):
  def forward(self, X, Y):
    return X + Y
  def right_inverse(Z):
    return Z, torch.zeros_like(Z)
```
rather than having to manually unpack a list or a tuple within the `forward` function.

At the moment the errors are a bit all over the place. This is to avoid having to check some properties of `forward` and `right_inverse` when they are registered. I left this like this for now, but I believe it'd be better to call these functions when they are registered to make sure the invariants hold and throw errors as soon as possible.

The invariants are the following:
1. The following code should be well-formed
```python
X = module.weight
Y = param.right_inverse(X)
assert isinstance(Y, Tensor) or isinstance(Y, collections.Sequence)
Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
```
in other words, if `Y` is a `Sequence` of `Tensor`s (we also check that the elements of the sequence are Tensors), then it is of the same length as the number of parameters `param.forward` accepts.

2. Always: `X.dtype == Z.dtype and X.shape == Z.shape`. This is to protect the user from shooting themselves in the foot, as it's too odd for a parametrization to change the metadata of a tensor.
3. If it's one-to-one: `X.dtype == Y.dtype`. This is to be able to do `X.set_(Y)` so that if a user first instantiates the optimiser and then puts the parametrisation, then we reuse `X` and the user does not need to add a new parameter to the optimiser. Alas, this is not possible when the parametrisation is many-to-one. The current implementation of `spectral_norm` and `weight_norm` does not seem to care about this, so this would not be a regression. I left a warning in the documentation though, as this case is a bit tricky.

I'm still missing to go over the formatting of the documentation, I'll do that tomorrow.

Pull Request resolved: #58488

Reviewed By: soulitzer

Differential Revision: D29100708

Pulled By: albanD

fbshipit-source-id: b9e91f439cf6b5b54d5fa210ec97c889efb9da38
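Invariant 1 above can be checked with a toy, pure-Python many-to-one parametrization (`Sum` mirrors the example in the summary, with floats standing in for tensors):

```python
class Sum:
    """Toy many-to-one parametrization, mirroring the example above."""
    def forward(self, x, y):
        return x + y

    def right_inverse(self, z):
        # Any decomposition with x + y == z is a valid preimage.
        return z, 0.0

param = Sum()
X = 5.0
Y = param.right_inverse(X)
# Y is a sequence, so it gets unpacked into forward.
Z = param.forward(*Y) if isinstance(Y, tuple) else param.forward(Y)
assert Z == X
print(Z)  # -> 5.0
```

The two components of `Y` would be stored as `original0` and `original1` in the scheme described in the summary.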
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026

Labels

cla signed, Merged, open source, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

A class to perform constrained optimization through a parametrization

9 participants