sparse.mm backward: performance improvements #94991
nikitaved wants to merge 41 commits into gh/nikitaved/24/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94991
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e862db2. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@albanD , @soulitzer , could you please also have a look?
`torch.sparse.mm` - faster and without syncs in "most" cases. cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang albanD zou3519 gqchen soulitzer Lezcano Varal7 [ghstack-poisoned]
pearu left a comment
I have a few clarification questions, otherwise LGTM!
```cpp
// larger. This function prepares inputs for `sparse_mask` such that `t` is
// projected onto `mask` by sorting `t` if uncoalesced and artificially marking it
// as coalesced all while `mask` is set to uncoalesced.
// The result of this projection is going to be uncoalesced, so it is up to the
```
Thanks for adding this comment!
This function returns three tensors. How do these relate to the inputs and to the result of the projection?
IIUC, for

```python
lhs, rhs, lhs_hash_opt = sparse_mask_like_prepare_sparse_inputs(t, mask)
```

we have the following invariants:

```python
lhs.indices() == t.indices()[lhs_hash]
lhs.values() == t.values()[lhs_hash]
rhs == mask
```

where `lhs_hash = lhs_hash_opt or slice(None)`, `lhs` may be a copy or a view of `t`, and `rhs` is an uncoalesced view of `mask`.
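To make these invariants concrete, here is a plain-Python toy model (an illustrative stand-in, not the actual PyTorch kernel; all names here are made up): projecting an uncoalesced `t` amounts to a stable sort of its flattened indices (the per-entry "hash"), with the values permuted by the same permutation.

```python
# Toy model of the invariants above (assumption: plain-Python stand-in,
# not the real kernel). A sparse COO "tensor" is (indices, values); lhs is
# t sorted by its per-entry hash, so that
#   lhs.indices == [t.indices[k] for k in perm]
#   lhs.values  == [t.values[k]  for k in perm]
ncols = 3
t_indices = [(1, 0), (0, 2), (1, 0)]  # uncoalesced: (1, 0) repeats
t_values = [10.0, 20.0, 30.0]

# Linearize (row, col) -> row * ncols + col: the per-entry "hash".
hashes = [r * ncols + c for r, c in t_indices]
perm = sorted(range(len(hashes)), key=lambda k: hashes[k])  # stable sort

lhs_indices = [t_indices[k] for k in perm]
lhs_values = [t_values[k] for k in perm]
print(lhs_indices)  # [(0, 2), (1, 0), (1, 0)]
print(lhs_values)   # [20.0, 10.0, 30.0]
```

Note that the duplicate `(1, 0)` entries survive the sort, which is why the result of the projection is still only "artificially" coalesced.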
```cpp
// the other way around depending on which arguments are coalesced and which are
// larger. This function prepares inputs for `sparse_mask` such that `t` is
// projected onto `mask` by sorting `t` if uncoalesced and artificially marking it
// as coalesced all while `mask` is set to uncoalesced.
```
What is the advantage of returning mask as uncoalesced?
If mask is uncoalesced but the other argument is not, the COO intersection kernel will return a tensor with the same indices as mask and will binary-search mask's hashes into the hashes of the other argument's indices, all without calls to sort and coalesce. The COO intersection kernel is heavily optimized to take advantage of is_coalesced of either argument, see https://github.com/pytorch/pytorch/pull/92976/files. As such, calling the intersection kernel with arguments (a, b) might produce a result with a.indices() or b.indices() depending on which is more performant (i.e. does not sync), based on whether a.is_coalesced or b.is_coalesced. To force this kernel to do what we want for sparse_mask, we mark certain arguments as "coalesced" (after a sort, if needed) and mark mask as uncoalesced to make sure the result has indices mask.indices().
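The binary-search idea can be sketched in plain Python (an illustrative model, not the actual CUDA kernel): the coalesced argument's linearized index hashes are sorted and unique, so each hash of the uncoalesced mask is looked up with bisection, and the result inherits mask's indices and order.

```python
import bisect

# Illustrative model of the intersection (assumption: plain Python, not
# the actual CUDA kernel). The coalesced argument has sorted, unique
# linearized index hashes; the uncoalesced mask's hashes are searched
# into them, and the result keeps mask's indices/order.
other_hashes = [0, 2, 5, 7]           # coalesced argument: sorted hashes
other_values = [1.0, 2.0, 3.0, 4.0]
mask_hashes = [5, 0, 5, 9]            # mask: any order, duplicates allowed

result = []
for h in mask_hashes:
    j = bisect.bisect_left(other_hashes, h)
    hit = j < len(other_hashes) and other_hashes[j] == h
    result.append(other_values[j] if hit else 0.0)
print(result)  # [3.0, 1.0, 3.0, 0.0]
```

Each lookup is O(log nnz) and independent of the others, which is why no sort, coalesce, or host-device sync of the mask is needed.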
This (forcing a coalesced input to be uncoalesced) sounds like a convoluted way to control the COO intersection kernel's behavior (to ensure that the result indices are the input indices).
It is still better than writing things from scratch, imho. The COO intersection kernel is not a public function with fixed semantics, but we want it to be fast and without any syncs if possible, so that, say, mul(a, b) and mul(b, a) produce the very same tensor without any syncs if either a or b is coalesced. In my opinion we can sacrifice some clarity for performance here, given that the array of ops we can implement with this kernel is quite substantial and performance critical (do you remember @amjames was working on removing calls to coalesce?). Fast sparse on CUDA is what sells it in the first place, and there is still some room for improvement in the intersection kernel as I see it... But I hear you, and I will probably modify the interface a bit now that I know the use cases better. The current design assumed just mul; I did not know about sparse_mask and its importance back then (nor about other use cases, like the non-symmetric context). That could be a nice follow-up once performance is here...
@pearu , do you have any other concerns? I will provide comfortable interfaces to COO intersection kernels as a follow-up to indeed reduce cognitive burden...
pearu left a comment
LGTM! Thanks, @nikitaved!
```diff
  if (grad_order == 0) {
    auto a_grad = _sparse_sparse_matmul(grad, b.conj().t());
-   return a_grad.mul(mask_ones_like(a.coalesce()));
+   return sparse_mask_like_grad(a, a_grad);
```
Ok, so essentially we're starting to write manual fusions. If we had torch.compile support we could presumably generate the code for this pattern. Since this is a hot path this is ok, but it's going to start to become difficult to test correctness.
Since _sparse_mask_projection is a native function you could even write a test that compares it to the pattern that you're fusing.
Is this a common general pattern? Maybe we can apply it in more locations. If we generalize the fused composite native function that you're adding here a bit, maybe it applies to more locations? It can then also subsume all the logic of sparse_mask_like_grad, and we can explicitly say "This function represents the fusion of this pattern". That will be easier to understand for future maintainers.
@cpuhrsch , alternatively, we can make this function composite implicit now that sparse_mask does support backward with COO inputs. This way we can remove the code for backward altogether. The result: much less code, still sync-less backward, albeit potentially slower compared to this impl (there is only one way to call backward; there is no back-and-forth between grad and input because of the fixed semantics of sparse_mask). But, yes, it is a generic pattern forced by some "sparse-semantics" functions, which might become obsolete with a differentiable sparse_mask or something similar but more flexible in enforcing the projection direction (aka a public interface for our COO intersection primitive kernel).
Another possibility: we can create a public function that does _sparse_mask_projection and/or sparse_mask (whichever is faster), which is explicitly differentiable and could be used to enforce sparse semantics while having a flexible backward. Then any sparse-semantics function would be a composition of this method and the underlying logic it implements.
I think we should go with your first approach for now and keep the amount of code low and then work on torch.compile support to provide these fusions. We can keep what you have for now as a reference and goal for that integration. Otherwise we'll have to undo these manual fusions again once we have torch.compile support.
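The correctness test suggested above (comparing the fused projection against the pattern it fuses) can be sketched on a plain-Python toy model; the names and data here are illustrative stand-ins, not the actual native functions:

```python
# Toy comparison of the two patterns discussed above (assumption: plain
# Python stand-ins for the sparse ops; not the actual PyTorch functions).
# Unfused: multiply the gradient by a sparse tensor of ones with a's pattern.
# Fused:   gather gradient entries directly at a's indices.
a_indices = [(0, 0), (1, 2)]          # a's sparsity pattern (coalesced)
a_grad = [[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]]            # dense-style result of the matmul

# Unfused reference: mask of ones, then elementwise multiply on the pattern.
ones = {idx: 1.0 for idx in a_indices}
unfused = {idx: a_grad[idx[0]][idx[1]] * ones[idx] for idx in a_indices}

# Fused "projection": directly gather a_grad at a's indices.
fused = {idx: a_grad[idx[0]][idx[1]] for idx in a_indices}

assert fused == unfused               # the equivalence such a test would check
print(fused)  # {(0, 0): 1.0, (1, 2): 6.0}
```

The point of the fused form is that it skips materializing the ones tensor and the elementwise multiply entirely; the test only pins down that both routes agree on the masked entries.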
cpuhrsch left a comment
I think this looks good, I'm just worried that future maintainers might find this difficult to grok quickly. What if we reframe this work as implementing a fused function for a specific pattern of operations?
@cpuhrsch , do you mind merging this for now to untangle some PR dependencies? Once
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
`torch.sparse.mm` - faster and without syncs in "most" cases.
Stack from ghstack (oldest at bottom):
cc @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer @ezyang @albanD @zou3519 @gqchen @soulitzer @lezcano @Varal7