perf(kimi_linear): replace einops rearrange with native torch ops in Kimi-Linear KDA path#20396
Merged
ispobock merged 8 commits into sgl-project:main on Mar 20, 2026
Conversation
…Kimi-Linear KDA path
Replace all 8 einops.rearrange calls with native torch operations in the
Kimi-Linear-48B model's KimiDeltaAttention hot path:
- kimi_linear.py: 2 rearrange calls → unflatten + squeeze/flatten
- kda_backend.py: 6 rearrange calls → unflatten + unsqueeze
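For illustration, the split/merge replacement pattern looks roughly like this (a minimal sketch with made-up shapes; the actual KDA tensors and dims differ):

```python
import torch

# Hypothetical shapes standing in for the KDA activations.
b, l, h, d = 2, 4, 8, 16
x = torch.randn(b, l, h * d)

# einops: rearrange(x, "b l (h d) -> b l h d", d=d)
heads = x.unflatten(-1, (h, d))      # (b, l, h, d)

# einops: rearrange(heads, "b l h d -> b l (h d)")
merged = heads.flatten(-2)           # back to (b, l, h*d)

assert torch.equal(merged, x)
```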
Profiled with Kimi-Linear-48B-A3B-Instruct (TP=2, 2xH100):
- Baseline: 12,600 einops::rearrange calls, avg 15.77us, total 198.7ms
- Optimized: 0 rearrange calls; replaced by:
    aten::unflatten   10,080 calls  avg 3.41us  total 34.4ms
    aten::flatten      2,520 calls  avg 2.57us  total  6.5ms
    aten::unsqueeze   12,604 calls  avg 2.62us  total 33.1ms
    aten::squeeze      7,520 calls  avg 5.29us  total 39.8ms
E2E throughput unchanged (2.65 vs 2.72 tok/s, within noise).
Mean TPOT: 71.48ms vs 71.81ms baseline.
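Per-op counts like the ones above can be collected with torch.profiler; a minimal sketch on a placeholder workload (not the actual KDA module):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Placeholder tensor standing in for KDA activations.
x = torch.randn(32, 64, 128)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(100):
        x.unflatten(-1, (8, 16)).unsqueeze(1).squeeze(1).flatten(-2)

# Per-operator call counts and CPU time, as summarized above.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```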
Force-pushed 5792f26 to 247f8e3
Collaborator: /tag-and-rerun-ci

zminglei approved these changes on Mar 12, 2026

Contributor (Author): /rerun-failed-ci (posted 4 times, including 1 similar comment)

Collaborator: /tag-and-rerun-ci

Contributor (Author): /rerun-failed-ci
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request on Mar 21, 2026: …Kimi-Linear KDA path (sgl-project#20396)
0-693 pushed a commit to 0-693/sglang that referenced this pull request on Mar 25, 2026: …Kimi-Linear KDA path (sgl-project#20396)
dutsc pushed a commit to dutsc/sglang that referenced this pull request on Mar 30, 2026: …Kimi-Linear KDA path (sgl-project#20396)
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request on Apr 7, 2026: …Kimi-Linear KDA path (sgl-project#20396)
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request on Apr 22, 2026: …Kimi-Linear KDA path (sgl-project#20396)
Motivation
einops.rearrange adds Python-level overhead (pattern parsing, backend dispatch, shape validation) on every call. In the Kimi-Linear-48B model's KimiDeltaAttention hot path, it is called thousands of times per forward pass. This PR replaces all 8 einops.rearrange calls with equivalent native PyTorch operations (unflatten, unsqueeze, squeeze, flatten) across 2 files, removing the einops import entirely from both.

Files Changed
- kimi_linear.py: 2 rearrange calls replaced in KimiDeltaAttention.forward()
- kda_backend.py: 6 rearrange calls replaced in forward_decode() and forward_extend()

Profiling Results
Profiled with Kimi-Linear-48B-A3B-Instruct on 2x H100 80GB (TP=2),
--disable-cuda-graph, 10 prompts (512 input / 128 output tokens). Results
scoped to nn.Module: KimiDeltaAttention_* spans via Perfetto SQL.

Baseline (einops):
    einops::rearrange  12,600 calls  avg 15.77us  total 198.7ms

Optimized (native torch):
    aten::unflatten    10,080 calls  avg 3.41us   total 34.4ms
    aten::unsqueeze    12,604 calls  avg 2.62us   total 33.1ms
    aten::squeeze       7,520 calls  avg 5.29us   total 39.8ms
    aten::flatten       2,520 calls  avg 2.57us   total  6.5ms

~1.75x reduction in reshape-related CPU overhead (198.70ms -> 113.74ms).
Correctness

Ran the same test in registered/models/test_kimi_linear_models.py: locally
launched an SGLang server, then ran the test against it; it passed.
All replacements produce bitwise-identical outputs to the original einops
operations, verified with torch.equal across multiple tensor shapes, dtypes,
and contiguity states.
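A stand-in for that check (illustrative only: reshape plays the role of the einops reference here, since the real test compares against einops.rearrange itself):

```python
import torch

# Verify the native replacement is bitwise identical to a reference
# reshape across shapes, dtypes, and contiguity states.
for shape in [(1, 1, 2, 8), (2, 16, 8, 64)]:
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        x = torch.randn(shape, dtype=dtype)
        for t in (x, x.transpose(1, 2)):          # contiguous and strided
            b, l, h, d = t.shape
            ref = t.reshape(b, l, h * d)          # reference result
            out = t.flatten(-2)                   # native replacement
            assert torch.equal(ref, out)
```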