Removes threadfence from topk kernel to improve AMD performance#145536
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145536
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit 7e0d3ec with merge base 0d28188:
- FLAKY - The following job failed but was likely due to flakiness present on trunk.
- BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Are these comments still in sync with, e.g., `auto ks_to_find_buffer = allocator.allocate(2 * numInputSlices * sizeof(uint32_t));` below?
In the kernel, the size of `ks_to_find_in` is still `num_slices`, so the kernel comment is correct. The allocation is now twice that size because we cannot update it in place.
@ngimel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…rch#145536)

Also marginally improves cuda perf.

Pull Request resolved: pytorch#145536
Approved by: https://github.com/eqy
…to eliminate redundant memory access (#164459)

# TL;DR
This PR removes the `torch.topk` regression introduced in torch 2.7.0 and delivers much better performance for large inputs. The table below reports execution times on an H20 for various input sizes with float32 data, extracting the top-100 values. The results indicate that this PR restores and improves performance, especially on large inputs.

| Input Shape | torch 2.6.0 (ms) | torch 2.8.0 (ms) | 2.8.0 + this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B) | 36.6 | 1564.1 | 25.6 |
| (1, 100M) | 3.56 | 17.4 | 2.54 |
| (1, 1M) | 0.135 | 0.145 | 0.098 |
| (512, 128000) | 1.33 | 1.33 | 1.32 |
| (8192, 128000) | 19.6 | 19.6 | 19.4 |

# Background
After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 went from **36 ms** to **1.6 s**. Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in [PR #145536](#145536).

# Analysis
`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:

1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.

Before [PR #145536](#145536), both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were complete before the reduction. In PR #145536, the global histogram computation was merged with the subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:

- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), `blocks_per_slice` CUDA blocks redundantly reload the same local histograms from global memory.

# This PR
To address this inefficiency, we introduce the following optimizations:

1. **Dedicated kernel**: Refactor the global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up the local histogram reads.

# Performance
We benchmarked `torch.topk` on an NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below demonstrate that this PR eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

| Input Shape | torch 2.6.0 (ms) | torch 2.8.0 (ms) | 2.8.0 + this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B) | 36.6 | 1564.1 | 25.6 |
| (1, 100M) | 3.56 | 17.4 | 2.54 |
| (1, 1M) | 0.135 | 0.145 | 0.098 |
| (512, 128000) | 1.33 | 1.33 | 1.32 |
| (8192, 128000) | 19.6 | 19.6 | 19.4 |

We have also verified the correctness of this PR with a variety of inputs.

Pull Request resolved: #164459
Approved by: https://github.com/ngimel, https://github.com/Skylion007