Rename sparse contiguous() to coalesce(); make out of place; speed up THC Embedding 10x#1302

Merged
soumith merged 4 commits into pytorch:master from adamlerer:sparse_coalesce
Apr 28, 2017
Conversation

@adamlerer
Contributor

Rename sparse tensor contiguous to coalesce, and be more sane about in-place vs out-of-place (i.e. always do it in place).

One thing I'm wondering: should the python methods tensor.indices(), tensor.values(), and tensor.__repr__() automatically coalesce first? This would have the advantage of hiding the implementation detail of coalescing from the Python user, guaranteeing expected invariants such as t.indices() == t.coalesce_().indices(). On the other hand, it limits what you can do in python (i.e. if you want to do something that uses the indices without coalescing for efficiency).

Opinions welcome!

@adamlerer
Contributor Author

I made some modifications to make sparse accumulation into dense gradients on GPU a lot faster (~10x) by avoiding coalescing and calling indexAdd. I also sped up coalesce by using something closer to what LookupTable does. I still need to go back and unify everything under indexAdd, but that will be trickier.

Here are my previous notes from #1147:


I dug into speeding up sparse for nn.Embedding based on my benchmark on GPU (https://gist.github.com/adamlerer/865c1a09000c7dc8208e1456255209c2). I think my findings apply more broadly though.

  • Most importantly, the contiguous (or reorder) operation, which merges rows with the same index, is inevitably slow on GPU because it demands a host synchronization. The reason is that the number of unique rows (calculated on GPU) becomes the new nnz, i.e. indices.size[0] / values.size[0] (which is a CPU field). So, it ends up being much faster to compute operations like spcadd (the main op for Embedding update) directly on the non-'contiguous' tensor, using indexAdd or LookupTable_accGradParameters.
    The potential problem with never compacting sparse tensors is that they can grow to unbounded size if you repeatedly call cadd without ever calling contiguous. Maybe make it the user's responsibility to call contiguous when necessary? This can be problematic when the cadd is buried in a library somewhere e.g. in autograd backward. I don't have a good answer here... We could have heuristics for when to compact a tensor. Something like this would be kinda clever (too clever?):
void THCSTensor_(cadd)(r, a, b) {
  ...
  // lower bound on r's unique rows: at least the larger operand's bound
  r->minUnique = max(a->minUnique, b->minUnique);
  ...
  if (r->nnz / r->minUnique > COMPACTION_THRESHOLD) {
    THCSTensor_(contiguous)(r);
    r->minUnique = r->nnz;
  }
}
  • LookupTable_accGradParameters is just computing indexAdd, yet it's implemented completely differently. The different implementation came from a desire a couple of years ago for LookupTable to be deterministic and therefore not use atomicAdd. But then why does indexAdd still use atomicAdd? Either we care about determinism or we don't... my guess is that we don't anymore, since even cudnn isn't deterministic.
    I benchmarked then, and the non-deterministic one is several times faster regardless of how many index collisions there are. We should decide whether we care about determinism, use that to pick an indexAdd kernel in THC, and then delegate LookupTable, spcadd, etc. all to that.

  • Autograd backwards on nn.Embedding spends about 1ms in Python autograd code, which after speeding things up takes >90% of the time for batch sizes below 1e4 (it was taking ~50% of the time before my kernel changes). So batch=1e3 isn't interesting, we should look at batch=1e4.

@adamlerer
Contributor Author

Btw, I'm still unsure about whether we should make coalesce an in-place op or not. It's an important decision with some trade-offs.

Pro in-place:

  • A coalesced tensor is always faster to operate on than an un-coalesced tensor.
  • Suppose you have an uncoalesced sparse tensor S and you repeatedly call coalescing ops on it:
S = S + S # now uncoalesced
for i in range(N):
  T = S * S # coalesces each time

This is going to be extremely inefficient.
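A toy model of this situation, in plain Python (not the real THS/THCS code; every name here is made up for illustration): sparse 1-D "tensors" are lists of (index, value) pairs, and a counter tracks only the coalesces that actually merge duplicates. Hoisting a single coalesce out of the loop makes all the later ones no-ops.

```python
# Toy 1-D sparse "tensors" as lists of (index, value) pairs -- a sketch,
# not the real implementation -- showing why an uncoalesced tensor feeding
# a loop of coalescing ops repeats the same merge work over and over.
coalesce_calls = 0  # counts only coalesces that actually do merge work

def coalesce(entries):
    """Return entries with duplicate indices summed, sorted by index."""
    global coalesce_calls
    indices = [i for i, _ in entries]
    if indices == sorted(set(indices)):
        return entries  # already coalesced: free
    coalesce_calls += 1
    merged = {}
    for i, v in entries:
        merged[i] = merged.get(i, 0.0) + v
    return sorted(merged.items())

def mul(a, b):
    """Elementwise product; needs coalesced operands, leaves inputs alone."""
    a, b = coalesce(a), coalesce(b)  # out-of-place: a and b stay uncoalesced
    db = dict(b)
    return [(i, v * db[i]) for i, v in a if i in db]

S = [(0, 1.0), (2, 3.0)] * 2          # S = S + S: duplicate indices, uncoalesced
for _ in range(5):
    T = mul(S, S)                     # re-coalesces S twice per iteration
assert coalesce_calls == 10

coalesce_calls = 0
S = coalesce(S)                       # hoist the coalesce out of the loop
for _ in range(5):
    T = mul(S, S)
assert coalesce_calls == 1            # all later coalesces are no-ops
```

With in-place coalesce the first loop would also pay only once, since the first `mul` would leave `S` coalesced; out-of-place pushes that responsibility onto the caller.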

Pro out-of-place:

  • const operands should not be mutated; therefore, if a const operand (e.g. an input to cadd) can get coalesced in place, that must be invisible to the Python user. That means t.indices(), t.values(), and t.__str__ would all have to call coalesce, which hides opportunities for optimizations that munge indices and values directly from Python. And if there are C bugs, you can get nasty behavior like a Heisenbug where the bug goes away when you call print on the tensor.

Always coalesce

There's a third option which is to always keep tensors in coalesced form. This is probably hard to make efficient. But it's possible that if we move nnz to be stored on the device, we can avoid host syncs and make it reasonably fast.

@adamlerer
Contributor Author

Heisenbugs suck. I'll switch to out-of-place coalesce. It will be up to the user to coalesce tensors explicitly if they're going to be operating on them multiple times.

@adamlerer changed the title from "Rename sparse tensor contiguous() to coalesce()" to "Rename sparse contiguous() to coalesce(); make out of place; speed up THC Embedding 10x" on Apr 21, 2017
@ezyang
Contributor

ezyang commented Apr 24, 2017

I'm not too qualified to review this diff, but some thoughts did occur to me:

  1. Does it make sense to still offer an in-place coalesce, even though operators won't call it directly? The main use case would be a user performing a bunch of sparse ops in place who then wants to coalesce the tensor once before passing it off.
  2. It would be great if we had some docs, because there is quite a bit of diversity in how duplicate indexes are handled. SciPy, for example, sums duplicate indexes together (like us, if I understand these diffs correctly), but TensorFlow drops duplicates (Sparse tensor construction given repeated indices tensorflow/tensorflow#371). I know the API is in flux, but it would still be great to have a paragraph or two stating what the basic assumptions in the implementation are.

@adamlerer
Contributor Author

Thanks @ezyang

  1. I don't see what value an in-place op provides... the in-place coalesce is no more efficient than the out-of-place one, because a sparse tensor is just a tiny wrapper around {THLongTensor indices, THTensor values}, so all you're creating is a new wrapper. The difference with the in-place version is that if anyone else was holding a reference to the tensor, it would be changed as well. Is there a reason we want that?

  2. Yeah, we definitely need docs. I will at least put a comment in coalesce but we do need to write the full docs for this library soon. I won't be able to get to it until at least late-May, so if you'd like to pick this up, it would be a huge help.

Re: duplicates, the current implementation must sum duplicates because we take advantage of this to do some ops efficiently. cadd just "adds" two sparse tensors by concatenating their indices and values. This way you can add together a bunch of tensors and then accumulate them into a dense tensor, without ever doing a (slow) coalesce.
P.S. We may be able to make coalesce much faster and then wouldn't need this, but that would require some substantial and tricky rewriting (esp. holding nnz on the GPU to avoid a host sync after coalesce).
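The concatenation trick can be sketched in plain Python (a toy model, not the real THS code; the helper names are made up): cadd just concatenates, and duplicates get summed either at coalesce time or when accumulating into a dense buffer, so both paths agree.

```python
# Sketch of why cadd can just concatenate: because coalesce SUMS duplicate
# indices, summing them later (at coalesce or dense-accumulation time)
# gives the same result as coalescing eagerly after every add.
def cadd(a, b):
    """'Add' two sparse vectors by concatenating (index, value) pairs."""
    return a + b

def coalesce(entries):
    """Sum values at duplicate indices; return sorted by index."""
    merged = {}
    for i, v in entries:
        merged[i] = merged.get(i, 0.0) + v
    return sorted(merged.items())

def acc_into_dense(dense, entries):
    """indexAdd-style accumulation: works on uncoalesced input directly."""
    for i, v in entries:
        dense[i] += v
    return dense

a = [(0, 1.0), (3, 2.0)]
b = [(0, 5.0), (1, 4.0)]
s = cadd(a, b)                      # uncoalesced: index 0 appears twice
assert coalesce(s) == [(0, 6.0), (1, 4.0), (3, 2.0)]
# Accumulating into dense without ever coalescing gives the same result:
assert acc_into_dense([0.0] * 4, s) == [6.0, 4.0, 0.0, 2.0]
```

This is exactly why a drop-duplicates policy (as in TensorFlow) would break the scheme: concatenation-as-addition only works when duplicates sum.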

@adamlerer
Contributor Author

Note to self: in the torch tests, self.assertEqual calls x.coalesce() on sparse tensors before comparison. But if coalesce is broken (e.g. to always return an empty tensor) these tests will pass.
If it's fast enough, I want to write a simple Python coalesce, and then replace x.coalesce() with checkedCoalesce(x), which runs the Python and torch versions and asserts they're equal.
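A hypothetical sketch of that checkedCoalesce idea (names made up; `impl_coalesce` stands in for the library's x.coalesce()): run a trivially correct Python coalesce alongside the implementation under test and assert they agree, so a coalesce that silently returns an empty tensor can no longer slip through the tests.

```python
# Sketch: cross-check an implementation's coalesce against an
# obviously-correct pure-Python reference. All names are hypothetical.
def reference_coalesce(indices, values):
    """Trivially correct coalesce: sum values at duplicate indices."""
    merged = {}
    for i, v in zip(indices, values):
        merged[i] = merged.get(i, 0.0) + v
    pairs = sorted(merged.items())
    return [i for i, _ in pairs], [v for _, v in pairs]

def checked_coalesce(indices, values, impl_coalesce):
    got = impl_coalesce(indices, values)
    want = reference_coalesce(indices, values)
    assert got == want, f"coalesce mismatch: {got} != {want}"
    return got

# A broken implementation (always returns empty) is now caught:
broken = lambda idx, val: ([], [])
try:
    checked_coalesce([0, 1, 0], [1.0, 2.0, 3.0], broken)
    caught = False
except AssertionError:
    caught = True
assert caught
```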

ezyang referenced this pull request in ezyang/pytorch-unattached Apr 26, 2017
This is a first cut at documentation on sparse tensors, on
the encouragement of @adamlerer.

It uses the new "coalesce" terminology that is currently
on review in #1302.  There is no documentation for the
methods; I've stuck to documenting aspects that are unlikely
to change.

Towards #1303.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@adamlerer
Contributor Author

This is ready for review and merge. @apaszke you'll note that you were right in #1147 to suggest making coalesce out-of-place.

Re: Embedding perf, here's the results of the benchmark script at https://gist.github.com/adamlerer/4159ddb4bd19e87de1b4bb3215a1db6f

$ python benchmark_sparse.py  1
cuda=True
Old done in 0.493923 s
New done in 0.617863 s
New done in 0.382167 s (no optimizer)

cf. 3s for new before this PR.

@ezyang
Contributor

ezyang commented Apr 27, 2017

👍

@soumith soumith merged commit f75ab85 into pytorch:master Apr 28, 2017
jjsjann123 added a commit to jjsjann123/pytorch that referenced this pull request Dec 15, 2021
eqy pushed a commit to eqy/pytorch that referenced this pull request Jan 20, 2022
hubertlu-tw pushed a commit to hubertlu-tw/pytorch that referenced this pull request Nov 1, 2022
jagadish-amd pushed a commit to jagadish-amd/pytorch that referenced this pull request Jan 29, 2025
jagadish-amd pushed a commit to jagadish-amd/pytorch that referenced this pull request May 15, 2025