[8/n] Migrate merge_attn_states, mamba, sampler to torch stable ABI #38841

Closed
mikaylagawarecki wants to merge 37 commits into vllm-project:main from mikaylagawarecki:new-stable-abi-phase8

mikaylagawarecki (Contributor) commented Apr 2, 2026

Commits to review

https://github.com/vllm-project/vllm/pull/38841/changes/ea6c06bc84378e855ce82ff08f302346d5dc4983..af6ab01a4f5055230635a499fc328afc444f3dba

Purpose

Stacked on #38783

Test Plan

pytest tests/kernels/attention/test_merge_attn_states.py
pytest tests/kernels/test_top_k_per_row.py
pytest tests/kernels/test_apply_repetition_penalties.py
pytest tests/kernels/mamba/test_mamba_ssm.py

Test Result

[Screenshots of test results: two captured 2026-04-02 around 4:23 PM, one captured 2026-04-03 at 1:43 PM]

@mikaylagawarecki mikaylagawarecki changed the title from "New stable abi phase8" to "[8/n] Migrate to torch stable ABI" Apr 2, 2026
@mergify mergify Bot added the ci/build, nvidia, and rocm labels Apr 2, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 2, 2026

gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces a new _C_stable_libtorch extension to support stable ABI, enabling cross-platform compatibility for custom operations. It includes significant refactoring of existing CUDA kernels to use torch::stable::Tensor and STD_TORCH_CHECK for stable ABI compliance, alongside updates to build configurations to support both CUDA and HIP. However, several critical issues were identified regarding const-correctness and in-place tensor modification logic. Specifically, multiple functions pass tensors as const references while attempting to obtain mutable pointers, and the in-place logic in hadacore_transform fails to correctly update the original storage when new storage is allocated.

Comment on lines +782 to +790
+ x = torch::stable::reshape(x, {-1, had_size});

  auto numel = x.numel();
  if (numel % 256 != 0) {
-   x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size}));
+   x = torch::stable::pad(x, {0, 0, 0, (256 - numel % 256) / had_size});
  }

  if (x.stride(-1) != 1) {
-   x = x.contiguous();
+   x = torch::stable::contiguous(x);

critical

The in-place logic for hadacore_transform is broken when the input tensor x is reassigned due to reshape, pad, or contiguous. Reassigning the torch::stable::Tensor& x reference only updates the local handle and does not modify the original storage passed by the caller. If inplace is true, any operations that create new storage (like pad or contiguous) must eventually copy the result back to the original tensor's storage using torch::stable::copy_. Currently, the copy_ check at line 807 will always be false because out and x share the same handle after reassignment, meaning the original storage remains unchanged.
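
The handle-versus-storage distinction described in this comment can be illustrated without any PyTorch dependency. The sketch below is a toy model, not the stable ABI: `ToyTensor`, `padded_to`, `double_rebinding`, and `double_with_copy_back` are all hypothetical names invented for illustration. It assumes the callee receives its own handle that shares storage with the caller, and shows that rebinding that handle to freshly allocated storage leaves the caller's buffer unchanged unless the result is copied back.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor handle: copies of the handle share storage.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> storage;

  explicit ToyTensor(std::vector<float> values)
      : storage(std::make_shared<std::vector<float>>(std::move(values))) {}

  // Always returns a handle to NEW storage, as pad() or contiguous() may do.
  ToyTensor padded_to(std::size_t n) const {
    std::vector<float> copy = *storage;
    if (copy.size() < n) copy.resize(n, 0.0f);
    return ToyTensor(std::move(copy));
  }

  bool shares_storage_with(const ToyTensor& other) const {
    return storage == other.storage;
  }
};

// Buggy variant: after rebinding, `x` points at new storage, so the doubled
// values never reach the buffer the caller is holding.
void double_rebinding(ToyTensor x, std::size_t padded_size) {
  x = x.padded_to(padded_size);
  for (float& v : *x.storage) v *= 2.0f;
}

// Fixed variant: keep a handle to the original storage and copy the result
// back when the working tensor ended up in different storage.
void double_with_copy_back(ToyTensor x, std::size_t padded_size) {
  const ToyTensor out = x;  // still shares storage with the caller
  x = x.padded_to(padded_size);
  for (float& v : *x.storage) v *= 2.0f;
  if (!x.shares_storage_with(out)) {
    assert(out.storage->size() <= x.storage->size());
    std::copy_n(x.storage->begin(), out.storage->size(), out.storage->begin());
  }
}
```

Calling `double_rebinding` on a three-element tensor leaves the caller's values as they were, while `double_with_copy_back` doubles them in the caller's storage and keeps its original length.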

mikaylagawarecki (Contributor, Author) replied:

pre-existing

Comment on lines +663 to +681
+ void selective_scan_fwd(const torch::stable::Tensor &u, const torch::stable::Tensor &delta,
+     const torch::stable::Tensor &A, const torch::stable::Tensor &B, const torch::stable::Tensor &C,
+     const std::optional<torch::stable::Tensor> &D_,
+     const std::optional<torch::stable::Tensor> &z_,
+     const std::optional<torch::stable::Tensor> &delta_bias_,
      bool delta_softplus,
-     const std::optional<torch::Tensor> &query_start_loc,
-     const std::optional<torch::Tensor> &cache_indices,
-     const std::optional<torch::Tensor> &has_initial_state,
-     const torch::Tensor &ssm_states,
+     const std::optional<torch::stable::Tensor> &query_start_loc,
+     const std::optional<torch::stable::Tensor> &cache_indices,
+     const std::optional<torch::stable::Tensor> &has_initial_state,
+     const torch::stable::Tensor &ssm_states,
      // used to identify padding entries if cache_indices provided
      // in case of padding, the kernel will return early
      int64_t null_block_id,
      int64_t block_size,
-     const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
-     const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
-     const std::optional<torch::Tensor> &initial_state_idx,
-     const std::optional<torch::Tensor> &cu_chunk_seqlen,
-     const std::optional<torch::Tensor> &last_chunk_indices) {
+     const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
+     const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
+     const std::optional<torch::stable::Tensor> &initial_state_idx,
+     const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
+     const std::optional<torch::stable::Tensor> &last_chunk_indices) {

high

The C++ signature for selective_scan_fwd uses const torch::stable::Tensor& for tensors that are actually modified in-place (like delta and ssm_states), as indicated by the Tensor! markers in ops.def and the use of const_cast to obtain mutable pointers later in the code (e.g., at lines 576-577). This violates const correctness and can lead to undefined behavior. These parameters should be passed as non-const references (torch::stable::Tensor&).
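
The signature convention recommended here can be sketched with a toy type. This is an illustration only, not the PR's code: `ToyBuffer` and `scan_step` are hypothetical, and the sketch assumes a type whose mutable accessor is a non-const member function. Read-only inputs are taken by const reference and the state that gets written is taken by non-const reference, so no `const_cast` is needed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical buffer type standing in for a tensor; not the torch::stable API.
struct ToyBuffer {
  std::vector<float> data;
  const float* const_ptr() const { return data.data(); }
  float* mutable_ptr() { return data.data(); }  // requires a non-const ToyBuffer
};

// `u` is only read, so it is a const reference. `state` is updated in place,
// so it is a non-const reference and the mutable pointer comes without a cast.
void scan_step(const ToyBuffer& u, float decay, ToyBuffer& state) {
  assert(u.data.size() == state.data.size());
  const float* in = u.const_ptr();
  float* st = state.mutable_ptr();
  for (std::size_t i = 0; i < u.data.size(); ++i) {
    st[i] = decay * st[i] + in[i];
  }
}
```

With this shape, the signature itself tells a reader which arguments the function may modify.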

mikaylagawarecki (Contributor, Author) replied:

pre-existing

Comment on lines +107 to +113
torch::stable::Tensor const& b_qweight,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& b_zeros, bool has_zp,
torch::stable::Tensor& b_qweight_reorder,
torch::stable::Tensor& b_scales_reorder,
std::optional<torch::stable::Tensor> const& b_zeros_reorder,
const int64_t K, const int64_t N, const int64_t N_32align) {

high

b_zeros_reorder is passed as a const reference, but mutable_data_ptr() is called on it at line 151. This is inconsistent and will likely fail to compile if mutable_data_ptr() is correctly implemented as a non-const member function. It should be passed as a non-const reference (std::optional<torch::stable::Tensor>&).
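
Why a const `std::optional` reference blocks mutable access follows from standard C++ rules: dereferencing a `const std::optional<T>&` yields `const T&`. The sketch below checks that with `static_assert` and shows the non-const alternative. `ToyZeros` and `clear_leading` are hypothetical names, and the sketch assumes, as the comment does, a mutable accessor declared as a non-const member.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical type whose mutable accessor is a non-const member function.
struct ToyZeros {
  std::vector<float> data;
  float* mutable_ptr() { return data.data(); }
};

// Dereferencing a const optional gives a const reference, so mutable_ptr()
// would not be callable through it; a non-const optional gives ToyZeros&.
static_assert(std::is_same_v<
              decltype(*std::declval<const std::optional<ToyZeros>&>()),
              const ToyZeros&>);
static_assert(std::is_same_v<
              decltype(*std::declval<std::optional<ToyZeros>&>()), ToyZeros&>);

// Taking the optional by non-const reference allows writing when it is set.
void clear_leading(std::optional<ToyZeros>& zeros, std::size_t n) {
  if (!zeros.has_value()) return;  // nothing to write for an absent argument
  assert(n <= zeros->data.size());
  float* p = zeros->mutable_ptr();
  for (std::size_t i = 0; i < n; ++i) p[i] = 0.0f;
}
```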

mikaylagawarecki (Contributor, Author) replied:

pre-existing

std::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,

Comment on lines +37 to 45
torch::stable::Tensor const& out,
torch::stable::Tensor const& lse,
torch::stable::Tensor const& q_nope,
torch::stable::Tensor const& q_pe,
torch::stable::Tensor const& kv_c_and_k_pe_cache,
torch::stable::Tensor const& seq_lens,
torch::stable::Tensor const& page_table,
torch::stable::Tensor const& workspace,
double sm_scale,

high

out is passed as a const reference, but it is the output tensor of the operation (marked as Tensor! out in ops.def). It should be passed as a non-const reference to allow obtaining a mutable data pointer for the kernel.

Comment on lines +434 to +448
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int null_block_id,"
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");

high

The ops.def for selective_scan_fwd incorrectly marks almost all input tensors as mutable (Tensor!). Only tensors that are actually modified by the kernel (like delta if used for output, and ssm_states) should have the ! suffix. Marking immutable inputs as mutable prevents PyTorch from performing certain optimizations and is misleading for users of the API.
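
To make the suggestion concrete, the sketch below places the schema quoted above next to a narrowed variant and counts the `!` markers in each. The narrowed variant is an assumption made purely for illustration: it supposes that only `delta`, `z_`, and `ssm_states` are written, which would need to be verified against the kernel. `count_mutable_markers` and both constant names are hypothetical.

```cpp
#include <cstddef>
#include <string_view>

// Counts the '!' mutability markers in a schema string.
constexpr std::size_t count_mutable_markers(std::string_view schema) {
  std::size_t n = 0;
  for (char c : schema) {
    if (c == '!') ++n;
  }
  return n;
}

// Schema as quoted in the review thread.
constexpr std::string_view kSchemaAsReviewed =
    "selective_scan_fwd(Tensor! u, Tensor! delta,"
    "Tensor! A, Tensor! B, Tensor! C,"
    "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
    "bool delta_softplus,"
    "Tensor? query_start_loc,"
    "Tensor? cache_indices,"
    "Tensor? has_initial_state,"
    "Tensor! ssm_states,"
    "int null_block_id,"
    "int block_size,"
    "Tensor? block_idx_first_scheduled_token,"
    "Tensor? block_idx_last_scheduled_token,"
    "Tensor? initial_state_idx,"
    "Tensor? cu_chunk_seqlen,"
    "Tensor? last_chunk_indices) -> ()";

// ASSUMPTION for illustration: only delta, z_, and ssm_states are written.
constexpr std::string_view kSchemaSketch =
    "selective_scan_fwd(Tensor u, Tensor! delta,"
    "Tensor A, Tensor B, Tensor C,"
    "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
    "bool delta_softplus,"
    "Tensor? query_start_loc,"
    "Tensor? cache_indices,"
    "Tensor? has_initial_state,"
    "Tensor! ssm_states,"
    "int null_block_id,"
    "int block_size,"
    "Tensor? block_idx_first_scheduled_token,"
    "Tensor? block_idx_last_scheduled_token,"
    "Tensor? initial_state_idx,"
    "Tensor? cu_chunk_seqlen,"
    "Tensor? last_chunk_indices) -> ()";

static_assert(count_mutable_markers(kSchemaAsReviewed) == 7);
static_assert(count_mutable_markers(kSchemaSketch) == 3);
```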

mikaylagawarecki (Contributor, Author) replied:

pre-existing

mergify Bot (Contributor) commented Apr 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 3, 2026
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pure move, no code changes. Preparatory step for stable ABI migration.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pure move, no code changes. Preparatory step for stable ABI migration.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Restructure the stable ABI extension build so it compiles on both CUDA
and HIP:
- Widen outer guard to include HIP
- Move CUDA-only sources (CUTLASS, FP4, AWQ, permute_cols) into
  a CUDA-conditional block
- Gate USE_CUDA / CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL to CUDA;
  define USE_ROCM for HIP
- Link PyTorch's bundled libamdhip64.so on ROCm to avoid a dual HIP
  runtime (from 985769a)
- Enable _C_stable_libtorch in setup.py for HIP builds

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Move 9 basic activation ops (silu_and_mul, mul_and_silu, gelu_and_mul,
gelu_tanh_and_mul, fatrelu_and_mul, swigluoai_and_mul, gelu_new,
gelu_fast, gelu_quick) from the _C extension to _C_stable_libtorch.

Convert ATen types/APIs to stable ABI equivalents:
- torch::Tensor -> torch::stable::Tensor
- ATen device guard/stream -> stable accelerator APIs
- VLLM_DISPATCH_FLOATING_TYPES -> VLLM_STABLE_DISPATCH_FLOATING_TYPES
- data_ptr -> mutable_data_ptr

Quantized activation ops (silu_and_mul_quant,
persistent_masked_m_silu_mul_quant) remain in _C.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Migrate static_scaled_fp8_quant, dynamic_scaled_fp8_quant, and
dynamic_per_token_scaled_fp8_quant from _C to _C_stable_libtorch.

Shared headers (common.cuh, utils.cuh) updated to work with both
targets: utils.cuh uses torch::headeronly types; common.cuh uses

Schema changed from (int,int)? to int[]? for group_shape to work
with TORCH_BOX (std::tuple is not trivially copyable).

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
@mergify mergify Bot removed the needs-rebase label Apr 3, 2026
@mikaylagawarecki mikaylagawarecki force-pushed the new-stable-abi-phase8 branch 3 times, most recently from 696c62d to 65a245b, April 3, 2026 17:41
@mikaylagawarecki mikaylagawarecki changed the title from "[8/n] Migrate to torch stable ABI" to "[8/n] Migrate merge_attn_states, mamba, sampler to torch stable ABI" Apr 3, 2026
Comment thread csrc/libtorch_stable/attention/merge_attn_states.cu Outdated
Comment thread csrc/libtorch_stable/mamba/selective_scan_fwd.cu
Comment thread csrc/libtorch_stable/topk.cu Outdated
janeyx99 (Contributor) commented Apr 3, 2026

minor nits that you should fix before merging, but otherwise change is coherent

Comment thread csrc/libtorch_stable/torch_bindings.cpp Outdated
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Replace ATen headers with torch stable ABI equivalents:
- torch::Tensor/at::Tensor -> torch::stable::Tensor
- at::cuda::OptionalCUDAGuard -> torch::stable::accelerator::DeviceGuard
- at::cuda::getCurrentCUDAStream() -> get_current_cuda_stream()
- TORCH_CHECK -> STD_TORCH_CHECK
- at::ScalarType -> torch::headeronly::ScalarType
- at::Half/at::BFloat16 -> c10::Half/c10::BFloat16
- C10_CUDA_CHECK/C10_HIP_CHECK -> manual cudaGetLastError() + STD_TORCH_CHECK
- CHECK_SHAPE macro rewritten with custom helper for IntHeaderOnlyArrayRef
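
The commit message does not show the rewritten shape check, so the following is only a guess at what such a helper could look like, written against standard-library types so it stands alone. `shape_matches` and `check_shape` are hypothetical names. A stable-ABI version would presumably read the sizes from the tensor and report failures through `STD_TORCH_CHECK` rather than throwing directly.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Returns true when `actual` has exactly the dimensions listed in `expected`.
inline bool shape_matches(const std::vector<int64_t>& actual,
                          std::initializer_list<int64_t> expected) {
  return actual.size() == expected.size() &&
         std::equal(actual.begin(), actual.end(), expected.begin());
}

// Hypothetical replacement for a CHECK_SHAPE-style macro: throws with a
// readable message when the shapes differ.
inline void check_shape(const std::vector<int64_t>& actual,
                        std::initializer_list<int64_t> expected,
                        const std::string& name) {
  assert(!name.empty());  // the name is only used to label the error message
  if (shape_matches(actual, expected)) return;
  std::ostringstream msg;
  msg << name << " must have shape (";
  bool first = true;
  for (int64_t dim : expected) {
    msg << (first ? "" : ", ") << dim;
    first = false;
  }
  msg << ")";
  throw std::invalid_argument(msg.str());
}
```

A call such as `check_shape(sizes, {batch, dim}, "u")` returns silently on a match and throws `std::invalid_argument` otherwise.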

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Apr 7, 2026
mergify Bot (Contributor) commented Apr 8, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 8, 2026
@mergify mergify Bot removed the needs-rebase label May 18, 2026
mergify Bot (Contributor) commented May 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Harry-Chen (Member) commented:

Superseded by newer PRs.

@Harry-Chen Harry-Chen closed this Jun 2, 2026
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Jun 2, 2026
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Jun 2, 2026
4 participants