perf(qwen3_5): replace einops rearrange with torch.flatten in GatedDe…#20386

Merged
ispobock merged 2 commits into sgl-project:main from vedantjh2:einops-to-native-qwen3_5
Mar 12, 2026

Contributor

@vedantjh2 vedantjh2 commented Mar 11, 2026

Motivation

einops.rearrange performs Python-level string parsing and validation on every call. In Qwen3_5GatedDeltaNet.forward(), the pattern rearrange(x, "... h d -> ... (h d)") simply flattens the last two dimensions, which is equivalent to x.flatten(-2). When those two dimensions are contiguous in memory, flatten returns a zero-copy view, and it never pays any pattern-parsing overhead.

This call runs on every forward pass through every GDN layer in the model, so the per-call overhead adds up: the profiled window below contains 720 such calls.
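As a minimal sketch of the equivalence described above (the tensor shape is a made-up example, not the shape used in the model), the snippet below compares flatten(-2) against an explicit reshape that merges the last two dimensions, and shows when the result is a view rather than a copy:

```python
import torch

# Hypothetical shape for illustration only: (tokens, heads, head_dim).
x = torch.randn(4, 8, 16)

# Merge the last two dimensions: (4, 8, 16) -> (4, 128).
flat = x.flatten(-2)

# Explicit spelling of "... h d -> ... (h d)" using reshape.
ref = x.reshape(*x.shape[:-2], -1)
same_values = torch.equal(flat, ref)

# For a contiguous input, flatten returns a view sharing the same storage.
shares_storage = flat.data_ptr() == x.data_ptr()

# For a non-contiguous input whose last two dims cannot be merged by
# strides alone, flatten falls back to a copy.
xt = x.transpose(0, 1)          # shape (8, 4, 16), non-contiguous
flat_t = xt.flatten(-2)         # shape (8, 64)
copied = flat_t.data_ptr() != xt.data_ptr()

print(flat.shape, same_values, shares_storage, copied)
```

In the PR, the tensor is flattened right after a reshape, so the view path is the likely case there, though that is an inference from the diff rather than something the PR states.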

Changes

  • Replace rearrange(core_attn_out, "... h d -> ... (h d)") with core_attn_out.flatten(-2) in Qwen3_5GatedDeltaNet.forward()
  • Remove the now-unused import: from einops import rearrange

Profiling Results

Profiled on Qwen3.5-9B, H100 80GB, TP=1, CUDA graphs disabled, using sglang.bench_serving with the torch profiler. The Perfetto SQL queries were restricted to slices that have a GatedDeltaNet ancestor slice.

Baseline (einops)

| name | call_count | avg_dur_us | total_dur_us |
| --- | --- | --- | --- |
| einops/einops.py(561): rearrange | 720 | 12.67 | 9125.43 |

Optimized (torch.flatten)

| name | call_count | avg_dur_us | total_dur_us | min_dur_us | max_dur_us |
| --- | --- | --- | --- | --- | --- |
| <built-in method flatten> | 720 | 4.74 | 3410.18 | 3.97 | 67.35 |
| aten::flatten | 720 | 2.50 | 1797.94 | 2.09 | 65.37 |

Summary

  • Per-call: 2.67x faster (12.67 us → 4.74 us)
  • Total time: 63% reduction (9125 us → 3410 us across 720 calls)
  • 5715 us saved per profiled window
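The summary figures follow from the two profiling tables. A quick sanity check of the arithmetic, using only the numbers reported above (the variable names are illustrative):

```python
# Values copied from the profiling tables in this PR.
baseline_avg_us, baseline_total_us = 12.67, 9125.43    # einops rearrange
optimized_avg_us, optimized_total_us = 4.74, 3410.18   # <built-in method flatten>

# Per-call speedup: ratio of average durations.
speedup = baseline_avg_us / optimized_avg_us

# Absolute and relative savings over the 720 profiled calls.
saved_us = baseline_total_us - optimized_total_us
reduction_pct = 100.0 * saved_us / baseline_total_us

print(f"{speedup:.2f}x faster, {saved_us:.0f} us saved, {reduction_pct:.0f}% reduction")
```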

Profiling Reproduction Commands

1. Launch server

export SGLANG_TORCH_PROFILER_DIR=/path/to/traces

python -m sglang.launch_server \
  --model-path Qwen/Qwen3.5-9B \
  --port 30000 --tp 1 --disable-cuda-graph

2. Run profiling benchmark

export PYTHONPATH=/path/to/sglang/python:$PYTHONPATH
export SGLANG_TORCH_PROFILER_DIR=/path/to/traces

python -m sglang.bench_serving \
  --backend sglang \
  --model Qwen/Qwen3.5-9B \
  --num-prompts 10 --random-input-len 256 --random-output-len 32 \
  --dataset-name random --profile

3. Perfetto SQL queries

Baseline (einops rearrange in GatedDeltaNet):

SELECT s.name, COUNT(*) AS call_count,
       AVG(s.dur)/1000.0 AS avg_dur_us, SUM(s.dur)/1000.0 AS total_dur_us
FROM slice s
WHERE s.name LIKE '%rearrange%'
  AND EXISTS (SELECT 1 FROM ancestor_slice(s.id) a WHERE a.name LIKE '%GatedDeltaNet%')
GROUP BY s.name;

Optimized (flatten in GatedDeltaNet):

SELECT s.name, COUNT(*) AS call_count,
       AVG(s.dur)/1000.0 AS avg_dur_us, SUM(s.dur)/1000.0 AS total_dur_us,
       MIN(s.dur)/1000.0 AS min_dur_us, MAX(s.dur)/1000.0 AS max_dur_us
FROM slice s
WHERE s.name LIKE '%flatten%'
  AND EXISTS (SELECT 1 FROM ancestor_slice(s.id) a WHERE a.name LIKE '%GatedDeltaNet%')
GROUP BY s.name;

Correctness

flatten(-2) is semantically identical to rearrange("... h d -> ... (h d)"): both merge the last two dimensions into one, so there is no behavior change.

GSM8K was run to check accuracy and end-to-end performance:

| Metric | Baseline (einops) | Optimized (native) |
| --- | --- | --- |
| Accuracy | 0.905 | 0.900 |
| Invalid | 0.000 | 0.000 |
| Latency | 28.075 s | 26.704 s |
| Output throughput | 789.6 tok/s | 828.1 tok/s |

@vedantjh2 vedantjh2 marked this pull request as ready for review March 11, 2026 21:47
@zminglei
Collaborator

/tag-and-rerun-ci

Comment thread python/sglang/srt/models/qwen3_5.py Outdated
  core_attn_out = self.norm(core_attn_out, z)
  core_attn_out = core_attn_out.reshape(z_shape_og)
- core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+ core_attn_out = core_attn_out.flatten(-2)
Collaborator

We can add a comment here noting the shape transformation, "... h d -> ... (h d)", to preserve readability.

@vedantjh2
Contributor Author

/rerun-failed-ci

…ltaNet

Replace `rearrange(core_attn_out, '... h d -> ... (h d)')` with
`core_attn_out.flatten(-2)` in Qwen3_5GatedDeltaNet.forward().

Removes the einops dependency from this model file, using a native
PyTorch operation that is semantically equivalent and 2.67x faster
(12.67us -> 4.74us avg per call, measured over 720 calls on H100).
@vedantjh2 vedantjh2 force-pushed the einops-to-native-qwen3_5 branch from c7ce97a to 2588b92 Compare March 12, 2026 01:09
@ispobock ispobock merged commit 9b55a98 into sgl-project:main Mar 12, 2026
88 of 101 checks passed