
A class to perform constrained optimization through a parametrization #28937

@lezcano

Description

🚀 Feature

A class that implements the following pattern:

x = nn.Parameter(torch.rand(4,5))
y = f(x)
# Use y here

computing f and its derivative only once per minibatch. Note that the value of y and the gradient of x do not depend on the elements of the minibatch.
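Concretely, taking f = torch.tanh as an illustrative choice, the pattern reads:

```python
import torch
import torch.nn as nn

x = nn.Parameter(torch.rand(4, 5))
y = torch.tanh(x)    # f is computed once here and y is reused afterwards
loss = y.sum()
loss.backward()      # the gradient flows back to x through f
print(x.grad.shape)  # torch.Size([4, 5])
```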

Motivation

A useful example

The exponential of matrices maps the skew-symmetric matrices onto the orthogonal matrices with positive determinant. One way to perform optimization with orthogonal constraints is the following:

x = nn.Parameter(torch.rand(4,4))
aux = x.triu(diagonal=1)
aux = aux - aux.t()     # aux is a skew-symmetric matrix 
y = exp(aux)            # y is an orthogonal matrix
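With recent PyTorch versions this construction can be checked end to end via torch.matrix_exp (a sketch; older versions need a hand-rolled matrix exponential):

```python
import torch

x = torch.rand(4, 4, requires_grad=True)
aux = x.triu(diagonal=1)
aux = aux - aux.t()              # skew-symmetric: aux.t() == -aux
y = torch.matrix_exp(aux)        # orthogonal with determinant +1
print(torch.allclose(y @ y.t(), torch.eye(4), atol=1e-5))  # True
```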

Problems implementing this under the current API

If one implements this naïvely within a forward method, the exponential of matrices will be computed every time the forward function is called, even if the value of x has not changed from one call to the next one (imagine this as parametrizing the kernel of an RNN).

One way of avoiding this is to have some kind of cache:

class Parametrized(nn.Module):
  def __init__(self, x): # x is a torch.Tensor here
    super().__init__()
    self.x = nn.Parameter(x)
    self.register_buffer("_y_cached", f(x))
    self.computed = True
  @property
  def y(self):
    if not self.computed:
      self._y_cached = f(self.x)
      self.computed = True
    return self._y_cached
  def set_dirty(self):
    self.computed = False

This does the trick in most situations, but it still fails when self.computed == False and we call self.y in a torch.no_grad() context. In that case, self._y_cached is updated, but its grad_fn won't be set, so we will get an error when calling loss.backward(). This can be solved by modifying the if statement as:

    if not self.computed or (self._y_cached.grad_fn is None and torch.is_grad_enabled()):

To make this implementation work, one still has to call self._y_cached.retain_grad() so that its gradient is retained even though it is not a leaf variable, and call the set_dirty function at the right times for every instance of Parametrized in the model.

This is a very convoluted solution that is full of nuances, like most cache implementations.
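For concreteness, here is a self-contained sketch of the whole caching pattern, condensed into one class (the choice of f and the training loop are illustrative, not part of the proposal):

```python
import torch
import torch.nn as nn

def f(x):
    # skew-symmetric -> orthogonal, as in the motivating example
    a = x.triu(diagonal=1)
    return torch.matrix_exp(a - a.t())

class Parametrized(nn.Module):
    def __init__(self, x):
        super().__init__()
        self.x = nn.Parameter(x)
        self._y_cached = None
        self.computed = False

    @property
    def y(self):
        # recompute when dirty, or when the cache was filled under
        # no_grad but gradients are now needed
        stale = (self._y_cached is not None
                 and self._y_cached.grad_fn is None
                 and torch.is_grad_enabled())
        if not self.computed or stale:
            self._y_cached = f(self.x)
            self.computed = True
        return self._y_cached

    def set_dirty(self):
        self.computed = False

model = Parametrized(torch.rand(4, 4))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):
    loss = (model.y - torch.eye(4)).pow(2).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    model.set_dirty()  # invalidate the cache once the parameter changed
```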

The bigger picture

One can perform optimization on a manifold M (the orthogonal matrices, the positive definite matrices, the matrices with non-zero determinant, the Grassmannian...) through a surjective function f : R^{n x m} -> M, just by parametrizing M in terms of a linear space. This class would allow for a rather direct implementation of optimization constrained to a manifold just by choosing an appropriate function. The exponential of matrices is just one example; another widely used one is L.mm(L.t()), where L is a lower-triangular matrix with positive elements on the diagonal. This function maps that set surjectively onto the manifold of symmetric positive definite matrices.
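For instance, the L.mm(L.t()) parametrization of the symmetric positive definite matrices can be sketched as follows (applying softplus to the diagonal is one illustrative way to keep it strictly positive):

```python
import torch

x = torch.rand(4, 4)
# lower-triangular with a strictly positive diagonal
L = x.tril(diagonal=-1) + torch.diag(torch.nn.functional.softplus(x.diagonal()))
y = L @ L.t()                              # symmetric positive definite
print(torch.linalg.eigvalsh(y).min() > 0)  # tensor(True)
```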

A simpler application would be a tensor with entries in [-1, 1], parametrized through tanh, or one with non-negative entries, parametrized through relu. In these cases the computational improvement would not be very big, but the model would be more concise, expressing better what one wants to do.
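As a sketch:

```python
import torch
import torch.nn as nn

raw = nn.Parameter(torch.randn(5))
bounded = torch.tanh(raw)  # entries in (-1, 1)
nonneg = torch.relu(raw)   # entries in [0, inf)
```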

One could also think of shipping constraints for a few common manifolds, like some of those mentioned above. This way PyTorch would support optimization on manifolds without any major changes to the API.

Pitch [EDITED]

The class nn.Module would have a new method register_constrained_parameter that could be used as follows:

lin = nn.Linear(10,20)
lin.register_constrained_parameter(name="weight", function=f, update="auto")

When called, register_constrained_parameter would create an inner attribute of type nn.Parameter named _{name}_unconstrained and a buffer named _{name}_cache. Assignments to lin.{name} would be redirected through nn.Module.__setattr__ to lin._{name}_unconstrained. The caching would be implemented in nn.Module.__getattr__.
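register_constrained_parameter does not exist yet; the following is a hypothetical sketch of the mechanics using a class-level property instead of hooking __setattr__/__getattr__ (which means it affects every instance of the class, unlike the proposal, and caching is omitted):

```python
import torch
import torch.nn as nn

def register_constrained_parameter(module, name, function):
    # hypothetical helper mirroring the proposed API
    original = getattr(module, name)
    delattr(module, name)  # remove the plain parameter
    module.register_parameter("_" + name + "_unconstrained",
                              nn.Parameter(original.detach().clone()))
    # note: installs a property on the *class*, affecting all instances
    setattr(type(module), name,
            property(lambda self: function(
                getattr(self, "_" + name + "_unconstrained"))))

lin = nn.Linear(10, 20)
register_constrained_parameter(lin, "weight", torch.tanh)
out = lin(torch.randn(2, 10))  # forward now uses tanh(_weight_unconstrained)
```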

update is a flag that takes "auto" or "manual". If update == "manual", the user has to notify the module that the tensor has been updated:

lin.weight.updated()

This has to be done after optim.step() has been called, but also if the user wants to modify the inner-variable manually (more on this below).

If update == "auto", then an updated flag is set to True when the gradients with respect to the parametrization are computed. This can easily be done with a register_hook on the tensor. This idea was proposed and discussed in #7313. This solution covers most practical use cases, where the user does not fiddle with the parameters manually during training, much less between the call to loss.backward() and optim.step().
The "auto" implementation would not work correctly by itself if module.{name} is accessed between these two calls. Another case would be the following:

class MyModule(nn.Module):
  def __init__(self, f):
    super().__init__()
    self.W = nn.Parameter(torch.rand(3, 4))
    # Register the constrained parameter
    self.register_constrained_parameter(name="W", function=f, update="auto")
    # Retrieve W
    a = self.W
    # Update _W_unconstrained manually
    torch.nn.init.eye_(self._W_unconstrained)

In these cases the user would have to call module.W.updated() manually after self._W_unconstrained has been updated.
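The register_hook trick behind update="auto" can be sketched as follows (the state dict and flag name are illustrative):

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.rand(3, 3))
state = {"updated": False}

# flip the flag as soon as a gradient for p is computed, mimicking
# the proposed update="auto" behaviour
p.register_hook(lambda grad: state.update(updated=True))

loss = (p ** 2).sum()
loss.backward()
print(state["updated"])  # True
```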

Further details:

  • When using a parameter constrained by a function, the parameter's shape changes to that of the image of f. In other words, if f creates a constant vector from a real number, R -> R^n, and we use this f to constrain a scalar parameter called "number", then accessing mymodule.number would return a vector.

  • Assignments to the variable:
    I am not sure about the __setattr__ part being redirected to _{name}_unconstrained, as it might not feel natural. This is a tricky one, as the user can edit the variable module.{name} in many ways (assignment, in-place operations, module.{name}.data...) and all these changes would not be reflected in the inner state of the module. Maybe this can just be put in the documentation since, at the end of the day, we are creating a buffer _{name}_cache and a parameter _{name}_unconstrained. The user should be aware of these, as they are a natural abstraction for a parametrization, and should access them directly to reinitialise the variable or to modify the cached buffer temporarily.
    A safer way to go about this would be to raise an exception if the user tries to modify a constrained variable in any way (this might not be easy at all). On the other hand, if we are willing to go down this road, we might as well notify directly to all the tensors whenever they have been updated, and the trick with the register_hook described before would not be necessary.

This whole scheme is implemented following this approach in
https://github.com/Lezcano/expRNN/blob/master/parametrization.py
as part of a slightly more involved scheme to parametrize manifolds.
