
sparse gradcheck: reparametrize some tests to remove masked=True #98490

Closed
nikitaved wants to merge 51 commits into gh/nikitaved/35/base from gh/nikitaved/35/head

Conversation


@nikitaved nikitaved commented Apr 6, 2023

Most sparse functions that operate on sparse tensors assume that sparsity is an optimization, so a green check with masked=False implies success with masked=True. Functions that assume masked (sparse) semantics and do not explicitly ignore gradients outside of the sparsity pattern can be re-parametrized with torch.sparse_mask so that gradcheck succeeds with masked=False. Hence, masked=True can be removed altogether.

Stack from ghstack (oldest at bottom):

cc @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer @ezyang @albanD @zou3519 @gqchen @soulitzer @lezcano @Varal7


pytorch-bot bot commented Apr 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98490

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e1bcfa2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the `topic: not user facing` label Apr 6, 2023
nikitaved added a commit that referenced this pull request Apr 6, 2023
@nikitaved nikitaved marked this pull request as draft April 6, 2023 10:07
nikitaved added a commit that referenced this pull request Apr 11, 2023
nikitaved added a commit that referenced this pull request Apr 11, 2023
nikitaved added a commit that referenced this pull request Apr 12, 2023
nikitaved added a commit that referenced this pull request Apr 12, 2023

pearu commented Jun 19, 2023

Most of the sparse functions that work with sparse tensors assume that sparse is an optimization, so a green check with masked=False will imply success with masked=True.

Consider torch.mm, which supports sparse inputs and should use non-masked semantics. However, your statement does not hold in the following example case:

>>> a = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_(True)
>>> torch.autograd.gradcheck(lambda x: torch.mm(x, x).to_dense(masked_grad=False), (a,), masked=False)
True
>>> torch.autograd.gradcheck(lambda x: torch.mm(x, x).to_dense(masked_grad=True), (a,), masked=True)
<snip>
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [2.0000, 3.0000, 0.0000, 2.0000],
        [1.0000, 0.0000, 3.0000, 1.0000],
        [0.0000, 1.0000, 2.0000, 6.0000]], dtype=torch.float64)
analytical:tensor([[0., 1., 2., 0.],
        [2., 3., 0., 2.],
        [1., 0., 3., 1.],
        [0., 1., 2., 6.]], dtype=torch.float64)

The statement holds when the input sparse tensor is a full tensor:

>>> a = torch.tensor([[10, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_(True)
>>> torch.autograd.gradcheck(lambda x: torch.mm(x, x).to_dense(masked_grad=True), (a,), masked=True)
True
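The mismatch in the first example above can be reproduced without gradcheck or even torch: a minimal pure-Python finite-difference sketch (all helper names here are hypothetical, not PyTorch API) shows that the dense and masked Jacobians of f(X) = X @ X at X = [[0, 1], [2, 3]] differ exactly in the row for the unstored (0, 0) entry, matching the tensors printed by gradcheck.

```python
# Finite-difference Jacobian of f(X) = X @ X for 2x2 matrices.
# Rows index input entries, columns index output entries, as in gradcheck.
# Pure-Python illustration; helper names are hypothetical, not PyTorch API.

def matmul2(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

def flatten(m):
    return [m[i][j] for i in range(2) for j in range(2)]

def jacobian(x, pattern, eps=1e-6):
    """Central differences; entries outside `pattern` are never perturbed,
    which models the masked=True numerical Jacobian."""
    rows = []
    for i in range(2):
        for j in range(2):
            if (i, j) not in pattern:
                rows.append([0.0] * 4)  # unstored entry: no perturbation
                continue
            xp = [r[:] for r in x]; xm = [r[:] for r in x]
            xp[i][j] += eps; xm[i][j] -= eps
            fp, fm = flatten(matmul2(xp, xp)), flatten(matmul2(xm, xm))
            rows.append([(p - m) / (2 * eps) for p, m in zip(fp, fm)])
    return rows

x = [[0.0, 1.0], [2.0, 3.0]]
full = {(0, 0), (0, 1), (1, 0), (1, 1)}
sparse_pattern = {(0, 1), (1, 0), (1, 1)}   # the zero at (0, 0) is not stored

J_dense = jacobian(x, full)             # dense semantics: row 0 is [0, 1, 2, 0]
J_masked = jacobian(x, sparse_pattern)  # masked semantics: row 0 is all zeros
```

The two Jacobians agree on every row except the one for the unstored (0, 0) entry, which is exactly the first-row mismatch gradcheck reports.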

Next, consider torch.sparse.mm that implements mm with masked semantics:

>>> a = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_(True)
>>> torch.autograd.gradcheck(lambda x: torch.sparse.mm(x, x).to_dense(masked_grad=True), (a,), masked=True)
True

but with masked=False, the gradcheck will fail:

>>> torch.autograd.gradcheck(lambda x: torch.sparse.mm(x, x).to_dense(masked_grad=False), (a,), masked=False)
<snip>
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0.0000, 1.0000, 2.0000, 0.0000],
        [2.0000, 3.0000, 0.0000, 2.0000],
        [1.0000, 0.0000, 3.0000, 1.0000],
        [0.0000, 1.0000, 2.0000, 6.0000]], dtype=torch.float64)
analytical:tensor([[0., 0., 0., 0.],
        [2., 3., 0., 2.],
        [1., 0., 3., 1.],
        [0., 1., 2., 6.]], dtype=torch.float64)

unless the input sparse tensor is full:

>>> a = torch.tensor([[10, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_(True)
>>> torch.autograd.gradcheck(lambda x: torch.sparse.mm(x, x).to_dense(masked_grad=False), (a,), masked=False)
True

Based on the above, I cannot confirm that masked=True can be removed.


nikitaved commented Jun 19, 2023

@pearu, it can be, once the gradients are properly mapped to the manifold with sparse_mask. torch.sparse.mm is just a composition of torch.mm and sparse_mask. More specifically, torch.sparse.mm should be equivalent to lambda x, y: torch.mm(x.sparse_mask(x), y.sparse_mask(y)). Sometimes we also want to restrict the in-flowing gradients, so another option is to additionally project the output: compute res = torch.mm(x.sparse_mask(x), y.sparse_mask(y)) and return res.sparse_mask(res). The combination of sparse_mask and mm gives us much more than just torch.sparse.mm.

Since masked=False performs a densification of sparse inputs, we can still test this function with

x_mask = x.detach().clone()
y_mask = y.detach().clone()

def mm(x, y):
    x = x.sparse_mask(x_mask) # project x onto x_mask, the in-flowing grad will have the same indices as x_mask
    y = y.sparse_mask(y_mask) # project y onto y_mask, the in-flowing grad will have the same indices as y_mask
    res = torch.mm(x, y)
    # make sure that in-flowing grads have the same indices as res
    return res.sparse_mask(res)

gradcheck(lambda x, y: mm(x, y).to_dense(masked_grad=False), (x, y), masked=False)
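The projection idea can also be checked numerically without torch. In the pure-Python sketch below (hypothetical helpers, not PyTorch API), composing the dense product with a projection onto the sparsity pattern makes the unrestricted dense Jacobian coincide with the masked one, which is what lets gradcheck(..., masked=False) cover the masked case.

```python
# Pure-Python sketch of the reparametrization: project(x, pattern) plays the
# role of x.sparse_mask(x_mask) for 2x2 matrices. Not PyTorch API.

def matmul2(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

def project(x, pattern):
    # zero out entries outside the pattern, the analogue of sparse_mask
    return [[x[i][j] if (i, j) in pattern else 0.0 for j in range(2)] for i in range(2)]

def dense_jacobian(f, x, eps=1e-6):
    # unrestricted central differences over ALL entries, i.e. masked=False
    rows = []
    for i in range(2):
        for j in range(2):
            xp = [r[:] for r in x]; xm = [r[:] for r in x]
            xp[i][j] += eps; xm[i][j] -= eps
            fp = [v for row in f(xp) for v in row]
            fm = [v for row in f(xm) for v in row]
            rows.append([(p - m) / (2 * eps) for p, m in zip(fp, fm)])
    return rows

pattern = {(0, 1), (1, 0), (1, 1)}          # the zero at (0, 0) is unstored
x = [[0.0, 1.0], [2.0, 3.0]]

def f_plain(t):
    return matmul2(t, t)

def f_reparam(t):
    tp = project(t, pattern)
    return matmul2(tp, tp)

J_plain = dense_jacobian(f_plain, x)
J_reparam = dense_jacobian(f_reparam, x)
# The row for the unstored (0, 0) entry vanishes in J_reparam, so the dense
# (masked=False) Jacobian of f_reparam equals the masked Jacobian of f_plain.
```

This mirrors the mm wrapper above: a dense gradcheck of the reparametrized function exercises the same gradients that masked=True would have checked on the original.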


pearu commented Jun 22, 2023

it can be once the gradients are properly mapped to the manifold with sparse_mask.
sparse.mm is just a composition of torch.mm and sparse_mask

I agree. This is what masked tensor support should handle: torch.mm on masked tensors is equivalent to the combination of torch.mm and torch.sparse_mask on value-mask pairs, as you exemplified.

torch.sparse.mm on sparse tensors is equivalent to torch.mm on masked tensors where the masks are defined by the sparsity patterns of the inputs. Recall that defining the mask via the sparsity pattern is something we want to get rid of eventually.

Our aim is to deprecate/eliminate torch.sparse.mm (together with the corresponding backward function) in favor of supporting masked tensors in torch.mm. Once we have this support and have applied the necessary deprecation procedures to handle BC-breaking changes (e.g. defining torch.sparse.mm in terms of torch.mm and converting sparse tensors to masked tensors), the masked keyword argument can safely be removed from gradcheck.

Until then, removing the usage of the masked keyword argument, as this PR does, is premature, IMHO. Especially so when this is achieved by removing the tests that exercise the masked=True case: by removing such tests, we risk that future changes to gradcheck introduce undetectable bugs in the gradcheck(..., masked=True) support, which we are not ready to drop yet.

In general, we should deprecate a feature before removing the corresponding tests. In this PR, tests are removed but deprecating the feature is not immediately possible because masked tensors support is not ready yet.


nikitaved commented Jun 23, 2023

I am in no rush to merge. This PR serves as a proof of concept: gradcheck does not need masked=True, as it complicates things, is error-prone when it comes to implementing backward formulas, and gives a false sense of security that the gradients are correct when it returns True.

nikitaved added 13 commits June 28, 2023 12:46
@pytorch pytorch deleted a comment from pytorch-bot bot Jul 3, 2023

github-actions bot commented Sep 1, 2023

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Sep 1, 2023
@github-actions github-actions bot closed this Oct 1, 2023
@facebook-github-bot facebook-github-bot deleted the gh/nikitaved/35/head branch November 1, 2023 14:25

Labels

ciflow/trunk, module: autograd, module: sparse, open source, Stale, topic: not user facing
