
perf(kimi_linear): replace einops rearrange with native torch ops in Kimi-Linear KDA path #20396

Merged
ispobock merged 8 commits into sgl-project:main from vedantjh2:einops-to-native-kimi
Mar 20, 2026

Conversation

@vedantjh2
Contributor

@vedantjh2 vedantjh2 commented Mar 12, 2026

Motivation

einops.rearrange adds Python-level overhead (pattern parsing, backend dispatch, shape validation) on every call. In the Kimi-Linear-48B model's KimiDeltaAttention hot path, this is called thousands of times per forward pass.

This PR replaces all 8 einops.rearrange calls with equivalent native PyTorch operations (unflatten, unsqueeze, squeeze, flatten) across 2 files, removing the einops import entirely from both.

Files Changed

  • kimi_linear.py — 2 rearrange calls replaced in KimiDeltaAttention.forward()
  • kda_backend.py — 6 rearrange calls replaced in forward_decode() and forward_extend()
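For illustration, a representative replacement (hypothetical shapes; the exact patterns vary across the eight call sites in the PR):

```python
import torch

# hypothetical shape: (batch, seq, heads * head_dim)
x = torch.randn(2, 7, 8 * 16)

# einops: rearrange(x, "b s (h d) -> b s h d", d=16)
y = x.unflatten(-1, (-1, 16))
assert y.shape == (2, 7, 8, 16)

# inverse, einops: rearrange(y, "b s h d -> b s (h d)")
z = y.flatten(-2)
assert torch.equal(z, x)
```

`unflatten`/`flatten` are plain view operations, so they skip einops' pattern parsing and backend dispatch on every call.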

Profiling Results

Profiled with Kimi-Linear-48B-A3B-Instruct on 2x H100 80GB (TP=2), --disable-cuda-graph, 10 prompts (512 input / 128 output tokens). Results scoped to nn.Module: KimiDeltaAttention_* spans via Perfetto SQL.

python -m sglang.launch_server \
  --model-path /shared/public/elr-models/moonshotai/Kimi-Linear-48B-A3B-Instruct \
  --port 30000 --tp 2 --disable-cuda-graph --trust-remote-code

Baseline (einops)

SELECT
  name,
  COUNT(*) AS calls,
  ROUND(AVG(dur) / 1e3, 4) AS avg_us,
  ROUND(SUM(dur) / 1e6, 2) AS total_ms
FROM slice
WHERE name = 'einops/einops.py(561): rearrange'
  AND id IN (
    SELECT s.id FROM slice s
    JOIN ancestor_slice(s.id) a ON a.name GLOB 'nn.Module: KimiDeltaAttention_*'
  )
GROUP BY name;

(Perfetto `dur` is in ns, so `/ 1e3` yields µs for the average and `/ 1e6` yields ms for the total, matching the table's units.)
| name | calls | avg (us) | total (ms) |
|---|---|---|---|
| einops::rearrange | 12,600 | 15.77 | 198.70 |

Optimized (native torch)

SELECT
  name,
  COUNT(*) AS calls,
  ROUND(AVG(dur) / 1e3, 4) AS avg_us,
  ROUND(SUM(dur) / 1e6, 2) AS total_ms
FROM slice
WHERE name IN ('aten::unflatten', 'aten::flatten', 'aten::unsqueeze', 'aten::squeeze')
  AND id IN (
    SELECT s.id FROM slice s
    JOIN ancestor_slice(s.id) a ON a.name GLOB 'nn.Module: KimiDeltaAttention_*'
  )
GROUP BY name;
| name | calls | avg (us) | total (ms) |
|---|---|---|---|
| aten::unflatten | 10,080 | 3.41 | 34.41 |
| aten::unsqueeze | 12,604 | 2.62 | 33.06 |
| aten::squeeze | 7,520 | 5.29 | 39.80 |
| aten::flatten | 2,520 | 2.57 | 6.47 |
| **Total** | | | **113.74** |

~1.75x reduction in reshape-related CPU overhead (198.70ms -> 113.74ms).
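As a sanity check on the totals above:

```python
# per-op totals (ms) from the optimized table
optimized_ms = 34.41 + 33.06 + 39.80 + 6.47
baseline_ms = 198.70

assert round(optimized_ms, 2) == 113.74
assert round(baseline_ms / optimized_ms, 2) == 1.75
```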

*(4 Perfetto trace screenshots attached, captured 2026-03-11.)*

Correctness

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl
| Metric | Baseline (einops) | Optimized (native torch) |
|---|---|---|
| Accuracy | 0.915 | 0.910 |
| Invalid | 0.000 | 0.000 |
| Latency | 41.6 s | 40.0 s |
| Throughput | 471.9 tok/s | 490.7 tok/s |

Also ran the same test as registered/models/test_kimi_linear_models.py: launched a local SGLang server with the command:

python -m sglang.launch_server   --model-path /shared/public/elr-models/moonshotai/Kimi-Linear-48B-A3B-Instruct   --port 30000 --tp 2 --trust-remote-code

then launched the test with the following command and result:

python -m sglang.test.few_shot_gsm8k --num-shots 5 --num-questions 200 --max-new-tokens 512 --parallel 128 --port 30000 --data-path /shared/public/data/gsm8k/test.jsonl
100%|██████████| 200/200 [00:09<00:00, 21.65it/s]
Accuracy: 0.900
Invalid: 0.000
Latency: 9.288 s
Output throughput: 2059.970 token/s

All replacements produce bitwise identical outputs to the original einops operations, verified with torch.equal across multiple tensor shapes, dtypes, and contiguity states.
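A minimal sketch of such a check (hypothetical shapes and dtypes; the einops pattern is shown only in a comment so the script needs just torch):

```python
import torch

def native_split_heads(x: torch.Tensor, head_dim: int) -> torch.Tensor:
    # einops equivalent: rearrange(x, "b s (h d) -> b s h d", d=head_dim)
    return x.unflatten(-1, (-1, head_dim))

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    for b, s, h, d in [(1, 1, 4, 16), (2, 33, 8, 64)]:
        x = torch.randn(b, s, h * d).to(dtype)
        assert torch.equal(native_split_heads(x, d), x.reshape(b, s, h, d))

        # non-contiguous input: strided view along the seq dimension
        nc = torch.randn(b, 2 * s, h * d).to(dtype)[:, ::2, :]
        assert torch.equal(native_split_heads(nc, d), nc.reshape(b, s, h, d))
```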

…Kimi-Linear KDA path

Replace all 8 einops.rearrange calls with native torch operations in the
Kimi-Linear-48B model's KimiDeltaAttention hot path:

- kimi_linear.py: 2 rearrange calls → unflatten + squeeze/flatten
- kda_backend.py: 6 rearrange calls → unflatten + unsqueeze

Profiled with Kimi-Linear-48B-A3B-Instruct (TP=2, 2xH100):
- Baseline: 12,600 einops::rearrange calls, avg 15.77us, total 198.7ms
- Optimized: 0 rearrange calls; replaced by:
    aten::unflatten  10,080 calls avg 3.41us  total 34.4ms
    aten::flatten     2,520 calls avg 2.57us  total  6.5ms
    aten::unsqueeze  12,604 calls avg 2.62us  total 33.1ms
    aten::squeeze     7,520 calls avg 5.29us  total 39.8ms

E2E throughput unchanged (2.65 vs 2.72 tok/s, within noise).
Mean TPOT: 71.48ms vs 71.81ms baseline.
@vedantjh2 force-pushed the einops-to-native-kimi branch from 5792f26 to 247f8e3 on March 12, 2026 at 02:21
@ispobock
Collaborator

ispobock commented Mar 12, 2026

/tag-and-rerun-ci

@vedantjh2
Contributor Author

/rerun-failed-ci

1 similar comment
@vedantjh2
Contributor Author

/rerun-failed-ci

@vedantjh2
Contributor Author

/rerun-failed-ci

@vedantjh2
Contributor Author

/rerun-failed-ci

@ispobock
Collaborator

/tag-and-rerun-ci

@vedantjh2
Contributor Author

/rerun-failed-ci

@ispobock ispobock merged commit db995fb into sgl-project:main Mar 20, 2026
220 of 248 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
3 participants