
sparse.mm backward: performance improvements#94991

Closed
nikitaved wants to merge 41 commits into gh/nikitaved/24/base from gh/nikitaved/24/head

Conversation

@nikitaved
Collaborator

@nikitaved nikitaved commented Feb 16, 2023

@pytorch-bot

pytorch-bot bot commented Feb 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94991

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e862db2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

nikitaved added a commit that referenced this pull request Feb 16, 2023
ghstack-source-id: 1526dd8
Pull Request resolved: #94991
@nikitaved nikitaved marked this pull request as draft February 16, 2023 17:53
nikitaved added a commit that referenced this pull request Feb 17, 2023
ghstack-source-id: 39235c7
Pull Request resolved: #94991
@nikitaved nikitaved marked this pull request as ready for review February 19, 2023 10:59
@nikitaved
Collaborator Author

@albanD , @soulitzer , could you please also have a look?

@nikitaved nikitaved requested a review from Skylion007 February 19, 2023 10:59
nikitaved added a commit that referenced this pull request Feb 19, 2023
ghstack-source-id: 66cfe12
Pull Request resolved: #94991
@nikitaved nikitaved added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 19, 2023
`torch.sparse.mm` - faster and without syncs in "most" cases.

cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
@nikitaved nikitaved requested a review from pearu April 19, 2023 11:30
Collaborator

@pearu pearu left a comment


I have a few clarification questions, otherwise LGTM!

// larger. This function prepares inputs for `sparse_mask` such that `t` is
// projected onto `mask` by sorting `t` if uncoalesced and artificially marking it
// as coalesced, all while `mask` is set to uncoalesced.
// The result of this projection is going to be uncoalesced, so it is up to the
Collaborator

Thanks for adding this comment!

This function returns three tensors. How do these relate to the inputs and to the result of the projection?

Collaborator

IIUC, for

lhs, rhs, lhs_hash_opt = sparse_mask_like_prepare_sparse_inputs(t, mask)

we have the following invariants

lhs.indices() == t.indices()[lhs_hash]
lhs.values() == t.values()[lhs_hash]
rhs == mask

where lhs_hash = lhs_hash_opt or slice(None), lhs may be a copy or a view of t, and rhs is an uncoalesced view of mask.
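These invariants can be illustrated with a small PyTorch sketch; the names below (`flat`, `lhs_hash`, etc.) are illustrative only, since the real helper is internal C++:

```python
import torch

# Illustrative sketch of the invariants above (not the actual internal
# helper): sorting an uncoalesced COO tensor's indices by their
# flattened ("hashed") form yields the permutation `lhs_hash`.
t = torch.sparse_coo_tensor(
    torch.tensor([[1, 0, 1], [0, 1, 0]]),   # unsorted, with a duplicate (1, 0)
    torch.tensor([10.0, 20.0, 30.0]),
    (2, 2),
)

nrows, ncols = t.size()
idx = t._indices()
flat = idx[0] * ncols + idx[1]              # linear "hash" of each index
lhs_hash = torch.sort(flat, stable=True).indices

lhs_indices = idx[:, lhs_hash]              # == t.indices()[:, lhs_hash]
lhs_values = t._values()[lhs_hash]          # == t.values()[lhs_hash]
# lhs_indices is now sorted (though possibly with duplicates), so the
# tensor can be treated as "coalesced" by kernels that only require
# sortedness.
```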

Collaborator Author

See the comment in #94991 (comment).

// the other way around depending on which arguments are coalesced and which are
// larger. This function prepares inputs for `sparse_mask` such that `t` is
// projected onto `mask` by sorting `t` if uncoalesced and artificially marking it
// as coalesced, all while `mask` is set to uncoalesced.
Collaborator

What is the advantage of returning mask as uncoalesced?

Collaborator Author

@nikitaved nikitaved Apr 26, 2023

If mask is uncoalesced but the other argument is not, the COO intersection kernel will return a tensor with the same indices as mask, doing a binary search of mask's hashes into the hashes of the other argument's indices, all without calls to sort or coalesce. The COO intersection kernel is heavily optimized to take advantage of is_coalesced on either argument, see https://github.com/pytorch/pytorch/pull/92976/files. As such, calling the intersection kernel with arguments (a, b) might produce results with a.indices() or b.indices(), depending on which is more performant (i.e., does not sync) given whether a.is_coalesced or b.is_coalesced. To force this kernel to do what we want for sparse_mask, we mark certain arguments as "coalesced" (after a sort, if needed) and mark mask as uncoalesced to make sure the result has indices mask.indices().
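That binary-search strategy can be sketched roughly as follows, with made-up toy tensors; the actual kernel lives in C++ and its interface differs:

```python
import torch

# Sketch of the intersection strategy described above (illustrative only):
# `mask`'s indices stay in their original order, while each of its hashes
# is binary-searched into the sorted hashes of the coalesced argument.
ncols = 4
mask_idx = torch.tensor([[1, 0, 1], [2, 1, 0]])    # uncoalesced order
other_idx = torch.tensor([[0, 1, 1], [1, 0, 2]])   # coalesced, hence sorted
other_val = torch.tensor([5.0, 7.0, 9.0])

mask_hash = mask_idx[0] * ncols + mask_idx[1]      # tensor([6, 1, 4])
other_hash = other_idx[0] * ncols + other_idx[1]   # tensor([1, 4, 6]), sorted

# No sort, no coalesce of mask: the result keeps mask's index order,
# with values gathered from the other argument.
pos = torch.searchsorted(other_hash, mask_hash).clamp_max(other_hash.numel() - 1)
matched = other_hash[pos] == mask_hash             # which mask entries intersect
result_val = other_val[pos] * matched              # zeros where no intersection
```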

Collaborator

This (forcing a coalesced input to be uncoalesced) sounds like a convoluted way to control COO intersection kernel functionality (to ensure that the result indices are the input indices).

Collaborator Author

@nikitaved nikitaved Apr 26, 2023

It is still better than writing things from scratch, imho. The COO intersection kernel is not a public function with fixed semantics, but we want it to be fast and without any syncs if possible, so that, say, mul(a, b) and mul(b, a) produce the very same tensor without any syncs if either a or b is coalesced. In my opinion we can sacrifice some clarity for performance here, given that the array of ops we can implement with this kernel is quite substantial and performance-critical (do you remember @amjames was working on removing calls to coalesce?). Fast sparse on CUDA is what sells it in the first place, and there is still some room for improvement in the intersection kernel as I see it... But I hear you, and I will probably modify the interface a bit now that I know the use cases better. The current design assumed just mul; I did not know about sparse_mask and its importance back then (or about other use cases, like the non-symmetric context). That could be a nice follow-up once performance is here...
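The symmetry claim for mul can be checked directly with a toy example (a quick sanity check, not the kernel's actual test suite):

```python
import torch

# Toy check of the claim above: mul(a, b) and mul(b, a) agree on the
# intersection regardless of which operand happens to be coalesced.
a = torch.sparse_coo_tensor([[1, 0], [0, 1]], [2.0, 1.0], (2, 2))             # left uncoalesced
b = torch.sparse_coo_tensor([[0, 1], [1, 0]], [3.0, 4.0], (2, 2)).coalesce()  # coalesced

ab = torch.mul(a, b).coalesce()
ba = torch.mul(b, a).coalesce()
assert torch.equal(ab.to_dense(), ba.to_dense())
```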

@nikitaved
Collaborator Author

@pearu , do you have any other concerns? I will provide comfortable interfaces to COO intersection kernels as a follow-up to indeed reduce cognitive burden...

Collaborator

@pearu pearu left a comment

LGTM! Thanks, @nikitaved!

if (grad_order == 0) {
auto a_grad = _sparse_sparse_matmul(grad, b.conj().t());
return a_grad.mul(mask_ones_like(a.coalesce()));
return sparse_mask_like_grad(a, a_grad);
Contributor

Ok, so essentially we're starting to write manual fusions. If we had torch.compile support we could presumably generate the code for this pattern. Since this is a hot path this is ok, but it's going to start to become difficult to test correctness.

Since _sparse_mask_projection is a native function you could even write a test that compares it to the pattern that you're fusing.

Is this a common general pattern? Maybe we can apply it in more locations. If we generalize the fused composite native function that you're adding here a bit, maybe it applies to more locations? It can then also subsume all the logic of sparse_mask_like_grad and we can explicitly say "This function represents the fusion of this pattern". That'll be easier to understand to future maintainers.
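The pattern being fused can be sketched roughly as below; `ones_like_pattern` is a hypothetical Python stand-in for `mask_ones_like` (the real code is internal C++):

```python
import torch

# Rough sketch of the pattern under discussion (illustrative only):
# project a sparse gradient onto `a`'s sparsity pattern by multiplying
# with ones placed at `a`'s indices.
def ones_like_pattern(s):
    # Hypothetical helper, not a PyTorch API: same pattern as `s`, all ones.
    s = s.coalesce()
    return torch.sparse_coo_tensor(s.indices(), torch.ones_like(s.values()), s.size())

a = torch.sparse_coo_tensor([[0, 1], [0, 1]], [1.0, 2.0], (2, 2))
a_grad = torch.sparse_coo_tensor([[0, 0, 1], [0, 1, 1]], [3.0, 4.0, 5.0], (2, 2))

projected = a_grad.mul(ones_like_pattern(a)).coalesce()
# Only entries at a's pattern survive: (0, 0) and (1, 1)
```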

Collaborator Author

@nikitaved nikitaved May 3, 2023
The reason will be displayed to describe this comment to others. Learn more.

@cpuhrsch , alternatively, we can make this function composite implicit now that sparse_mask does support backward with COO inputs. This way we can remove the code for backward altogether. The result: much less code and still sync-less backward, albeit potentially slower compared to this impl (there is only one way to call backward; there is no back-and-forth between grad and input because of the fixed semantics of sparse_mask). But, yes, it is a generic pattern forced by some "sparse-semantics" functions, which might become obsolete with a differentiable sparse_mask or something similar but more flexible in enforcing the projection direction (aka a public interface for our COO intersection primitive kernel).
Another possibility: we can create a public function which does either _sparse_mask_projection and/or sparse_mask (whichever is faster), which is explicitly differentiable and could be used to enforce sparse semantics while having a flexible backward. Then any sparse-semantics function would be a composition of this method and the underlying logic it implements.

Contributor

I think we should go with your first approach for now and keep the amount of code low and then work on torch.compile support to provide these fusions. We can keep what you have for now as a reference and goal for that integration. Otherwise we'll have to undo these manual fusions again once we have torch.compile support.

Contributor

@cpuhrsch cpuhrsch left a comment

I think this looks good, I'm just worried that future maintainers might find this difficult to grok quickly. What if we reframe this work as implementing a fused function for a specific pattern of operations?

@nikitaved
Collaborator Author

nikitaved commented Jun 12, 2023

@cpuhrsch , do you mind merging this for now to untangle some PR dependencies? Once sparse_mask backward is in, I will make sparse-mm composite implicit and we will likely remove a LOT of code already in the codebase...

@nikitaved nikitaved requested a review from cpuhrsch June 12, 2023 13:26
@cpuhrsch
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot facebook-github-bot deleted the gh/nikitaved/24/head branch June 16, 2023 14:17

Labels

ciflow/trunk - Trigger trunk jobs on your pull request
Merged
module: autograd - Related to torch.autograd, and the autograd engine in general
module: sparse - Related to torch.sparse
open source
release notes: sparse - release notes category


6 participants