[8/n] Migrate merge_attn_states, mamba, sampler to torch stable ABI #38841
mikaylagawarecki wants to merge 37 commits into
Code Review
This pull request introduces a new _C_stable_libtorch extension to support stable ABI, enabling cross-platform compatibility for custom operations. It includes significant refactoring of existing CUDA kernels to use torch::stable::Tensor and STD_TORCH_CHECK for stable ABI compliance, alongside updates to build configurations to support both CUDA and HIP. However, several critical issues were identified regarding const-correctness and in-place tensor modification logic. Specifically, multiple functions pass tensors as const references while attempting to obtain mutable pointers, and the in-place logic in hadacore_transform fails to correctly update the original storage when new storage is allocated.
  x = torch::stable::reshape(x, {-1, had_size});

  auto numel = x.numel();
  if (numel % 256 != 0) {
-   x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size}));
+   x = torch::stable::pad(x, {0, 0, 0, (256 - numel % 256) / had_size});
  }

  if (x.stride(-1) != 1) {
-   x = x.contiguous();
+   x = torch::stable::contiguous(x);
The in-place logic for hadacore_transform is broken when the input tensor x is reassigned due to reshape, pad, or contiguous. Reassigning the torch::stable::Tensor& x reference only updates the local handle and does not modify the original storage passed by the caller. If inplace is true, any operations that create new storage (like pad or contiguous) must eventually copy the result back to the original tensor's storage using torch::stable::copy_. Currently, the copy_ check at line 807 will always be false because out and x share the same handle after reassignment, meaning the original storage remains unchanged.
pre-existing
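The failure mode described in the hadacore_transform comment above can be reproduced without PyTorch. What follows is a minimal sketch under stated assumptions, not the PR's code: `ToyTensor`, `scaled_copy`, `transform_broken`, and `transform_fixed` are hypothetical names, and a `shared_ptr` to a vector stands in for tensor storage. It only illustrates why comparing two handles after reassignment never triggers a copy-back.

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor handle: the handle refers to shared
// storage, loosely like a tensor object refers to its underlying buffer.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> storage;
};

// Returns a handle to NEW storage, as an out-of-place op (pad, contiguous)
// may do when it cannot return a view.
ToyTensor scaled_copy(const ToyTensor& t, float factor) {
  auto fresh = std::make_shared<std::vector<float>>(*t.storage);
  for (float& v : *fresh) v *= factor;
  return ToyTensor{fresh};
}

// Broken "in-place" variant: x is rebound to fresh storage and out is taken
// from x afterwards, so the copy-back check compares a handle with itself.
void transform_broken(ToyTensor& x) {
  x = scaled_copy(x, 2.0f);
  ToyTensor out = x;
  if (out.storage != x.storage) {  // never true
    *x.storage = *out.storage;
  }
}

// Fixed variant: keep the caller's handle untouched, work on a separate
// handle, and copy the result back into the original storage.
void transform_fixed(ToyTensor& x) {
  ToyTensor work = scaled_copy(x, 2.0f);
  if (work.storage != x.storage) {
    *x.storage = *work.storage;  // plays the role of copy_
  }
}
```

In the toy model the element count does not change; a real fix would also have to undo any padding or reshape before copying back.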
  void selective_scan_fwd(const torch::stable::Tensor &u, const torch::stable::Tensor &delta,
  const torch::stable::Tensor &A, const torch::stable::Tensor &B, const torch::stable::Tensor &C,
  const std::optional<torch::stable::Tensor> &D_,
  const std::optional<torch::stable::Tensor> &z_,
  const std::optional<torch::stable::Tensor> &delta_bias_,
  bool delta_softplus,
- const std::optional<torch::Tensor> &query_start_loc,
- const std::optional<torch::Tensor> &cache_indices,
- const std::optional<torch::Tensor> &has_initial_state,
- const torch::Tensor &ssm_states,
+ const std::optional<torch::stable::Tensor> &query_start_loc,
+ const std::optional<torch::stable::Tensor> &cache_indices,
+ const std::optional<torch::stable::Tensor> &has_initial_state,
+ const torch::stable::Tensor &ssm_states,
  // used to identify padding entries if cache_indices provided
  // in case of padding, the kernel will return early
  int64_t null_block_id,
  int64_t block_size,
- const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
- const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
- const std::optional<torch::Tensor> &initial_state_idx,
- const std::optional<torch::Tensor> &cu_chunk_seqlen,
- const std::optional<torch::Tensor> &last_chunk_indices) {
+ const std::optional<torch::stable::Tensor> &block_idx_first_scheduled_token,
+ const std::optional<torch::stable::Tensor> &block_idx_last_scheduled_token,
+ const std::optional<torch::stable::Tensor> &initial_state_idx,
+ const std::optional<torch::stable::Tensor> &cu_chunk_seqlen,
+ const std::optional<torch::stable::Tensor> &last_chunk_indices) {
The C++ signature for selective_scan_fwd uses const torch::stable::Tensor& for tensors that are actually modified in-place (like delta and ssm_states), as indicated by the Tensor! markers in ops.def and the use of const_cast to obtain mutable pointers later in the code (e.g., at lines 576-577). This violates const correctness and can lead to undefined behavior. These parameters should be passed as non-const references (torch::stable::Tensor&).
pre-existing
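The signature style the selective_scan_fwd comment above asks for can be illustrated with a self-contained toy. This is only a sketch: `ToyBuffer` and `accumulate_into` are hypothetical, and the const-qualification of the toy's accessors is an assumption made for illustration; it says nothing about how `torch::stable::Tensor` declares its own accessors.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical buffer type standing in for a tensor. In this toy, the
// mutable accessor is deliberately a non-const member function.
struct ToyBuffer {
  std::vector<float> data;
  const float* const_data_ptr() const { return data.data(); }
  float* mutable_data_ptr() { return data.data(); }
};

// The read-only input stays const; the state that gets written is taken by
// non-const reference, so no const_cast is needed and the signature itself
// documents the write.
void accumulate_into(const ToyBuffer& input, ToyBuffer& state) {
  const float* in = input.const_data_ptr();
  float* out = state.mutable_data_ptr();
  const std::size_t n =
      input.data.size() < state.data.size() ? input.data.size() : state.data.size();
  for (std::size_t i = 0; i < n; ++i) {
    out[i] += in[i];
  }
}
```

With this shape, passing a `const ToyBuffer` as `state` is rejected at compile time instead of being silently written through a cast.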
  torch::stable::Tensor const& b_qweight,
  torch::stable::Tensor const& b_scales,
  std::optional<torch::stable::Tensor> const& b_zeros, bool has_zp,
  torch::stable::Tensor& b_qweight_reorder,
  torch::stable::Tensor& b_scales_reorder,
  std::optional<torch::stable::Tensor> const& b_zeros_reorder,
  const int64_t K, const int64_t N, const int64_t N_32align) {
b_zeros_reorder is passed as a const reference, but mutable_data_ptr() is called on it at line 151. This is inconsistent and will likely fail to compile if mutable_data_ptr() is correctly implemented as a non-const member function. It should be passed as a non-const reference (std::optional<torch::stable::Tensor>&).
pre-existing
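The same idea applies to an optional output, as raised in the b_zeros_reorder comment above. A minimal sketch, assuming a hypothetical `ScratchBuffer` type whose mutable accessor is non-const; `fill_if_present` is likewise an illustrative name, not a function from the PR.

```cpp
#include <optional>
#include <vector>

// Hypothetical buffer type; the mutable accessor is non-const in this toy.
struct ScratchBuffer {
  std::vector<int> data;
  int* mutable_data_ptr() { return data.data(); }
};

// The optional output is taken by non-const reference. With a
// `const std::optional<ScratchBuffer>&` parameter, the call to
// mutable_data_ptr() below would not compile for this toy type.
void fill_if_present(std::optional<ScratchBuffer>& out, int value) {
  if (!out.has_value()) {
    return;  // nothing to write when the optional output is absent
  }
  int* p = out->mutable_data_ptr();
  for (std::size_t i = 0; i < out->data.size(); ++i) {
    p[i] = value;
  }
}
```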
  torch::stable::Tensor const& out,
  torch::stable::Tensor const& lse,
  torch::stable::Tensor const& q_nope,
  torch::stable::Tensor const& q_pe,
  torch::stable::Tensor const& kv_c_and_k_pe_cache,
  torch::stable::Tensor const& seq_lens,
  torch::stable::Tensor const& page_table,
  torch::stable::Tensor const& workspace,
  double sm_scale,

  "selective_scan_fwd(Tensor! u, Tensor! delta,"
  "Tensor! A, Tensor! B, Tensor! C,"
  "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
  "bool delta_softplus,"
  "Tensor? query_start_loc,"
  "Tensor? cache_indices,"
  "Tensor? has_initial_state,"
  "Tensor! ssm_states,"
  "int null_block_id,"
  "int block_size,"
  "Tensor? block_idx_first_scheduled_token,"
  "Tensor? block_idx_last_scheduled_token,"
  "Tensor? initial_state_idx,"
  "Tensor? cu_chunk_seqlen,"
  "Tensor? last_chunk_indices) -> ()");
The ops.def for selective_scan_fwd incorrectly marks almost all input tensors as mutable (Tensor!). Only tensors that are actually modified by the kernel (like delta if used for output, and ssm_states) should have the ! suffix. Marking immutable inputs as mutable prevents PyTorch from performing certain optimizations and is misleading for users of the API.
pre-existing
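For reference, a schema following the suggestion in the comment above might look roughly like the string below. This is a sketch under assumptions, not a verified change: it keeps the `!` marker only on `delta` and `ssm_states` (the two tensors the reviewer names as written), leaves `z_` exactly as in the existing schema because the thread does not say whether it is written, and wraps the string in a hypothetical helper, `selective_scan_fwd_schema_sketch`, so it can be inspected in isolation.

```cpp
#include <string>

// Hypothetical helper returning a revised schema string. Only `delta` and
// `ssm_states` carry the mutability marker; `z_` is copied unchanged from
// the existing schema, which is an assumption rather than a verified fact.
inline std::string selective_scan_fwd_schema_sketch() {
  return
      "selective_scan_fwd(Tensor u, Tensor! delta,"
      "Tensor A, Tensor B, Tensor C,"
      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
      "bool delta_softplus,"
      "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
      "Tensor! ssm_states,"
      "int null_block_id,"
      "int block_size,"
      "Tensor? block_idx_first_scheduled_token,"
      "Tensor? block_idx_last_scheduled_token,"
      "Tensor? initial_state_idx,"
      "Tensor? cu_chunk_seqlen,"
      "Tensor? last_chunk_indices) -> ()";
}
```

Any such change would need to be checked against what the kernel actually writes before the markers are removed.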
Force-pushed from 9c49b55 to 3ac0d1b.
This pull request has merge conflicts that must be resolved before it can be
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pure move, no code changes. Preparatory step for stable ABI migration.
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Restructure the stable ABI extension build so it compiles on both CUDA and HIP:
- Widen outer guard to include HIP
- Move CUDA-only sources (CUTLASS, FP4, AWQ, permute_cols) into a CUDA-conditional block
- Gate USE_CUDA / CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL to CUDA; define USE_ROCM for HIP
- Link PyTorch's bundled libamdhip64.so on ROCm to avoid a dual HIP runtime (from 985769a)
- Enable _C_stable_libtorch in setup.py for HIP builds
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Move 9 basic activation ops (silu_and_mul, mul_and_silu, gelu_and_mul, gelu_tanh_and_mul, fatrelu_and_mul, swigluoai_and_mul, gelu_new, gelu_fast, gelu_quick) from the _C extension to _C_stable_libtorch.
Convert ATen types/APIs to stable ABI equivalents:
- torch::Tensor -> torch::stable::Tensor
- ATen device guard/stream -> stable accelerator APIs
- VLLM_DISPATCH_FLOATING_TYPES -> VLLM_STABLE_DISPATCH_FLOATING_TYPES
- data_ptr -> mutable_data_ptr
Quantized activation ops (silu_and_mul_quant, persistent_masked_m_silu_mul_quant) remain in _C.
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Migrate static_scaled_fp8_quant, dynamic_scaled_fp8_quant, and dynamic_per_token_scaled_fp8_quant from _C to _C_stable_libtorch.
Shared headers (common.cuh, utils.cuh) updated to work with both targets: utils.cuh uses torch::headeronly types; common.cuh uses
Schema changed from (int,int)? to int[]? for group_shape to work with TORCH_BOX (std::tuple is not trivially copyable).
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Force-pushed from 696c62d to 65a245b.
Force-pushed from 65a245b to 523616f.
Minor nits that you should fix before merging, but otherwise the change is coherent.
Replace ATen headers with torch stable ABI equivalents:
- torch::Tensor/at::Tensor -> torch::stable::Tensor
- at::cuda::OptionalCUDAGuard -> torch::stable::accelerator::DeviceGuard
- at::cuda::getCurrentCUDAStream() -> get_current_cuda_stream()
- TORCH_CHECK -> STD_TORCH_CHECK
- at::ScalarType -> torch::headeronly::ScalarType
- at::Half/at::BFloat16 -> c10::Half/c10::BFloat16
- C10_CUDA_CHECK/C10_HIP_CHECK -> manual cudaGetLastError() + STD_TORCH_CHECK
- CHECK_SHAPE macro rewritten with custom helper for IntHeaderOnlyArrayRef
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Force-pushed from 523616f to af6ab01.
Superseded by newer PRs.
Commits to review
https://github.com/vllm-project/vllm/pull/38841/changes/ea6c06bc84378e855ce82ff08f302346d5dc4983..af6ab01a4f5055230635a499fc328afc444f3dba
Purpose
Stacked on #38783
Test Plan
Test Result