
[7/n] Migrate pos_encoding and norm kernels to libtorch stable ABI #38783

Closed
mikaylagawarecki wants to merge 31 commits into vllm-project:main from mikaylagawarecki:new-stable-abi-phase7

@mikaylagawarecki mikaylagawarecki commented Apr 2, 2026

Contributor

Purpose

Stacked on #38757. Commits to review: https://github.com/vllm-project/vllm/pull/38783/changes/deea6618c38afb4735b442c61e2697c273654292..8754a4250584115db08113e0889313c939d85eb6

Note: some declarations remain in csrc/ops.h even though they have been moved to csrc/libtorch_stable/ops.h, because the CPU build also uses them. These are:

  • Layernorm kernels: rms_norm, fused_add_rms_norm
  • Pos encoding kernels: rotary_embedding

Test Plan

pytest tests/kernels/core/test_pos_encoding.py
pytest tests/kernels/core/test_fused_qk_norm_rope.py
pytest tests/kernels/core/test_layernorm.py
pytest tests/kernels/core/test_fused_quant_layernorm.py

Test Result

[Three screenshots of the test results, captured 2026-04-02 between 2:45 PM and 2:46 PM; images not included.]

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify Bot added the ci/build, nvidia, and rocm labels Apr 2, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 2, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a new _C_stable_libtorch extension to support a stable ABI, enabling better compatibility across different PyTorch versions and environments. It refactors several core kernels and quantization operations to use this stable ABI, including layernorm, positional encoding, and various quantization kernels. Additionally, it enables this stable extension for both CUDA and HIP backends. I have identified a potential compilation issue where the hadacore_transform declaration is placed outside the appropriate conditional compilation block, which may cause build failures on non-CUDA backends.

Comment on lines +158 to +159
torch::stable::Tensor hadacore_transform(torch::stable::Tensor& x,
                                         bool inplace);

Contributor

high

The hadacore_transform function is compiled only for CUDA, but its declaration is outside the #ifdef VLLM_CUDA block. This will lead to compilation errors when building for other backends like ROCm/HIP. This declaration should be moved inside the #ifdef VLLM_CUDA block, before the #endif on line 156.
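
The guard placement discussed in this comment can be illustrated with a minimal, self-contained sketch. `DEMO_CUDA`, `demo_transform`, and `run_transform` are hypothetical names standing in for `VLLM_CUDA` and a backend-specific kernel; none of them are vLLM symbols. The sketch keeps the backend-only function inside the guard and exposes a portable wrapper around it.

```cpp
#include <string>

// DEMO_CUDA is a hypothetical stand-in for a backend macro such as VLLM_CUDA.
// Pass -DDEMO_CUDA to the compiler to build the backend-specific path.
#ifdef DEMO_CUDA
// Backend-only function: declared and defined under the same guard, so
// builds for other backends never see a symbol they cannot link.
std::string demo_transform(const std::string& x) { return "cuda:" + x; }
#endif

// Portable entry point (hypothetical): callers do not need their own #ifdef.
std::string run_transform(const std::string& x) {
#ifdef DEMO_CUDA
  return demo_transform(x);
#else
  return "unsupported:" + x;
#endif
}
```

In C++, a declaration that nothing references does not fail to compile on its own; the failure appears at link time once some code calls a function whose definition was never built. Keeping the declaration under the same guard as the definition turns that late link error into an earlier compile error at the call site.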

@mikaylagawarecki mikaylagawarecki Apr 2, 2026

Contributor Author

This predates this stack; see:

vllm/csrc/ops.h

Line 296 in 08ed2b9

torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace);

@mikaylagawarecki mikaylagawarecki force-pushed the new-stable-abi-phase7 branch 2 times, most recently from 30e40eb to 59af75d April 2, 2026 03:57

mergify Bot commented Apr 2, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 2, 2026
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pure move, no code changes. Preparatory step for stable ABI migration.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pure move, no code changes. Preparatory step for stable ABI migration.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Restructure the stable ABI extension build so it compiles on both CUDA
and HIP:
- Widen outer guard to include HIP
- Move CUDA-only sources (CUTLASS, FP4, AWQ, permute_cols) into
  a CUDA-conditional block
- Gate USE_CUDA / CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL to CUDA;
  define USE_ROCM for HIP
- Link PyTorch's bundled libamdhip64.so on ROCm to avoid a dual HIP
  runtime (from 985769a)
- Enable _C_stable_libtorch in setup.py for HIP builds

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Move 9 basic activation ops (silu_and_mul, mul_and_silu, gelu_and_mul,
gelu_tanh_and_mul, fatrelu_and_mul, swigluoai_and_mul, gelu_new,
gelu_fast, gelu_quick) from the _C extension to _C_stable_libtorch.

Convert ATen types/APIs to stable ABI equivalents:
- torch::Tensor -> torch::stable::Tensor
- ATen device guard/stream -> stable accelerator APIs
- VLLM_DISPATCH_FLOATING_TYPES -> VLLM_STABLE_DISPATCH_FLOATING_TYPES
- data_ptr -> mutable_data_ptr

Quantized activation ops (silu_and_mul_quant,
persistent_masked_m_silu_mul_quant) remain in _C.
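
The `data_ptr -> mutable_data_ptr` item above distinguishes pointers that are only read from pointers that are written. A minimal CPU sketch of that split follows. `DemoTensor` and `demo_silu_and_mul` are hypothetical stand-ins written for this illustration; they are not the libtorch stable ABI types or the vLLM kernel. The sketch assumes the usual definition of `silu_and_mul`: SiLU applied to the first half of the input, multiplied elementwise by the second half.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// DemoTensor is a hypothetical stand-in for a tensor handle. It is NOT the
// libtorch class; it only mirrors the read-only vs. writable accessor split.
class DemoTensor {
 public:
  explicit DemoTensor(std::vector<float> values) : storage_(std::move(values)) {}
  const float* const_data_ptr() const { return storage_.data(); }
  float* mutable_data_ptr() { return storage_.data(); }
  std::size_t numel() const { return storage_.size(); }

 private:
  std::vector<float> storage_;
};

// CPU reference for one row: the input holds 2*d values and
// out[i] = silu(input[i]) * input[d + i], where silu(v) = v / (1 + exp(-v)).
void demo_silu_and_mul(DemoTensor& out, const DemoTensor& input) {
  const std::size_t d = out.numel();
  if (input.numel() != 2 * d) {
    throw std::invalid_argument("input must hold exactly 2 * out.numel() values");
  }
  const float* in = input.const_data_ptr();  // input is only read
  float* dst = out.mutable_data_ptr();       // output is written
  for (std::size_t i = 0; i < d; ++i) {
    const float x = in[i];
    dst[i] = (x / (1.0f + std::exp(-x))) * in[d + i];
  }
}
```

Requesting a writable pointer only for the output makes the intent of each argument visible at the call site, which is the kind of distinction the accessor rename suggests.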

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
@mergify mergify Bot removed the needs-rebase label Apr 2, 2026
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
…ch stable ABI

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
…ed_add_rms_norm_static_fp8_quant) to torch stable ABI

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
…libtorch_stable

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
@mikaylagawarecki mikaylagawarecki changed the title from "[7/n] libtorch stable ABI" to "[7/n] Migrate pos_encoding and norm kernels to libtorch stable ABI" Apr 2, 2026
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Comment thread csrc/ops.h
@@ -91,12 +91,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,

Contributor

is this used by cpu too?

Contributor Author

yep

ops.def(
    "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
    "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

Comment thread csrc/type_convert.cuh
#include <torch/headeronly/util/Half.h>

#ifndef USE_ROCM
#include <cuda.h>

Contributor

why do we need this?

@mikaylagawarecki mikaylagawarecki Apr 2, 2026

Contributor Author

I think torch/all.h or some other torch include pulled this in before, but now we need to include it explicitly for CUDA_VERSION, which is used below on line 50.

Comment thread CMakeLists.txt
"csrc/libtorch_stable/fused_qknorm_rope_kernel.cu"
"csrc/libtorch_stable/layernorm_kernels.cu"
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu")

Contributor

so cleannnn cleanest cmake change so far!

Collaborator

😅

@mikaylagawarecki mikaylagawarecki marked this pull request as ready for review April 3, 2026 15:09
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Apr 3, 2026

mergify Bot commented Apr 8, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 8, 2026
stmcgovern added a commit to TorchedHat/pytorch-stable-abi-transform that referenced this pull request Apr 23, 2026
Gaps identified by comparing tool output against a real 1,858-line manual
migration PR (vllm-project/vllm#38783):

Rules.h:
- Add torch::k* scalar type shorthands (kFloat, kBFloat16, kInt8, kInt32, etc.)
- Add c10::/at:: scalar type rewrites (Half, BFloat16, Float8_e4m3fn, etc.)
- Add CUDA check macro rules (C10_CUDA_CHECK, AT_CUDA_CHECK, C10_CUDA_KERNEL_LAUNCH_CHECK)
- Add TORCH_CHECK_NOT_IMPLEMENTED → STD_TORCH_CHECK_NOT_IMPLEMENTED
- Add more method-to-free-function rules (sum, pad, new_zeros, permute, slice,
  index_select, repeat, expand)

AstCallbacks.cpp:
- Register new type names and scalar type shorthands in AST matchers
- Register new method names for method-to-free-function conversion

Verifier.cpp:
- Detect torch::k* shorthands as unstable
- Detect C10_CUDA_CHECK, AT_CUDA_CHECK, C10_CUDA_KERNEL_LAUNCH_CHECK
- Detect TORCH_CHECK_NOT_IMPLEMENTED
- Detect .dtype() usage (unstable caffe2::TypeMeta, use .scalar_type())
- Detect torch::TensorOptions (needs decomposition into explicit args)
- Detect at::Half, c10::Half, c10::BFloat16, c10::Float8_* types
- Detect at::elementSize (use tensor.element_size())
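
The verifier checks listed above can be approximated with a small text scanner. The commit message describes AST matchers; the sketch below swaps in plain whole-token string matching, so it is only an illustration of the idea. The function names and the chosen subset of patterns are assumptions made for this sketch, not the contents of Verifier.cpp.

```cpp
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// True for characters that can appear inside a C++ identifier.
bool is_ident_char(char c) {
  return std::isalnum(static_cast<unsigned char>(c)) != 0 || c == '_';
}

// Whole-token search, so TORCH_CHECK_NOT_IMPLEMENTED does not match inside
// its stable replacement STD_TORCH_CHECK_NOT_IMPLEMENTED.
bool contains_token(const std::string& source, const std::string& token) {
  std::size_t pos = source.find(token);
  while (pos != std::string::npos) {
    const std::size_t end = pos + token.size();
    const bool left_ok = pos == 0 || !is_ident_char(source[pos - 1]);
    const bool right_ok = end == source.size() || !is_ident_char(source[end]);
    if (left_ok && right_ok) return true;
    pos = source.find(token, pos + 1);
  }
  return false;
}

// Returns the unstable identifiers found in `source`, in table order.
// The table is an illustrative subset taken from the commit message.
std::vector<std::string> find_unstable(const std::string& source) {
  static const std::vector<std::string> kPatterns = {
      "C10_CUDA_CHECK",
      "AT_CUDA_CHECK",
      "C10_CUDA_KERNEL_LAUNCH_CHECK",
      "TORCH_CHECK_NOT_IMPLEMENTED",
      "torch::TensorOptions",
      "at::elementSize",
  };
  std::vector<std::string> hits;
  for (const std::string& pattern : kPatterns) {
    if (contains_token(source, pattern)) hits.push_back(pattern);
  }
  return hits;
}
```

A text scan like this also matches inside comments and string literals, which is one reason an AST-based tool is the more robust choice for real migrations.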
@mergify mergify Bot removed the needs-rebase label May 18, 2026

mergify Bot commented May 18, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@Harry-Chen

Member

Superseded by newer PRs.

@Harry-Chen Harry-Chen closed this Jun 2, 2026
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Jun 2, 2026
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Jun 2, 2026

4 participants