
Enable mixed type LayerNorm kernel for NSA indexer#12044

Merged
Fridge003 merged 4 commits into sgl-project:main from akhilg-nv:layer_norm on Nov 4, 2025

Conversation

@akhilg-nv (Contributor) commented Oct 24, 2025

Motivation

Currently we cast the LayerNorm input to fp32, run native Torch LayerNorm, and cast the result back. Instead, we pull in a more efficient TRT-LLM kernel (via FlashInfer) that supports mixed-precision inputs directly.
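For context, a minimal sketch of the cast-based path being replaced (the helper name layernorm_fp32_cast is illustrative only, not the actual SGLang code):

import torch
import torch.nn.functional as F

def layernorm_fp32_cast(x, weight, bias, eps=1e-6):
    # Old path: upcast bf16 activations to fp32, run native LayerNorm, then
    # downcast the result. The extra casts add elementwise kernels around the
    # LayerNorm itself; the FlashInfer/TRT-LLM kernel consumes bf16 activations
    # with fp32 weights directly and skips the round trip.
    orig_dtype = x.dtype
    out = F.layer_norm(x.float(), (x.shape[-1],), weight=weight, bias=bias, eps=eps)
    return out.to(orig_dtype)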

Modifications

Add the FlashInfer LayerNorm kernel and apply it in the NSA indexer.

Accuracy Tests

python -m sglang.launch_server --model-path model_fp4/ --tp 4 --dp 4 --enable-dp-attention --reasoning-parser deepseek-v3

python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 120000 --repeat 8 --thinking-mode deepseek-v3

sample results

# with old layernorm
Repeat: 8, mean: 0.783
Scores: ['0.788', '0.793', '0.808', '0.783', '0.742', '0.788', '0.783', '0.778']
Writing report to /tmp/gpqa_model_fp4_.html
{'chars': np.float64(1464.3333333333333), 'chars:std': np.float64(356.735535690296), 'score:std': np.float64(0.41573970964154905), 'score': np.float64(0.7777777777777778)}
Writing results to /tmp/gpqa_model_fp4_.json
Total latency: 1422.309 s
Score: 0.778

# with new layernorm

Repeat: 8, mean: 0.785
Scores: ['0.758', '0.763', '0.788', '0.753', '0.823', '0.808', '0.793', '0.793']
Writing report to /tmp/gpqa_model_fp4_.html
{'chars': np.float64(1474.1515151515152), 'chars:std': np.float64(341.7505399975447), 'score:std': np.float64(0.40520665017240837), 'score': np.float64(0.7929292929292929)}
Writing results to /tmp/gpqa_model_fp4_.json
Total latency: 1290.363 s
Score: 0.793

nsys trace shows the new kernel being used:

[image: nsys timeline showing the new FlashInfer LayerNorm kernel in use]

Benchmarking and Profiling

Checklist


akhilg-nv changed the title from "Enable mixed type LayerNorm kernel for NSA" to "Enable mixed type LayerNorm kernel for NSA indexer" on Oct 24, 2025
@hlu1 (Collaborator) commented Oct 24, 2025

@akhilg-nv Please add the nsys profile result showing how much time the three kernels were using before applying the change

Comment on lines +315 to +318
if not self.elementwise_affine:
return self.forward_native(x)

if _flashinfer_layernorm_available and x.dtype == torch.bfloat16 and self.dtype == torch.float32:
Collaborator: Combine the two if branches

Contributor Author (akhilg-nv):
I did some benchmarking and found that for most cases I looked at, using the flashinfer kernel with weight = ones and bias = zeros is still faster than torch or torch.compile. So I combined these execution paths (by default I now initialize weight/bias to ones/zeros). Ideally we could make the affine transform optional in the flashinfer kernel, but this shouldn't affect DSv3.2 perf anyway.
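A rough sketch of the combined dispatch described above (a sketch only; flashinfer_layernorm stands in for the FlashInfer wrapper, and the real SGLang code may differ):

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    # Single branch: weight/bias are initialized to ones/zeros when
    # elementwise_affine is False, so the affine and non-affine cases can
    # share the fused kernel path.
    if (
        _flashinfer_layernorm_available
        and x.dtype == torch.bfloat16
        and self.dtype == torch.float32
    ):
        return flashinfer_layernorm(x, self.weight, self.bias, self.eps)
    return self.forward_native(x)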

@akhilg-nv (Contributor Author)

@Fridge003 Could you provide insight on resolving the failing CI/CD tests?

I see errors that seem unrelated to my changes, like:

  File "/sglang-checkout/python/sglang/srt/models/deepseek_v2.py", line 2098, in forward_absorb_fused_mla_rope_prepare
    forward_batch.attn_backend.forward_metadata
AttributeError: 'HybridAttnBackend' object has no attribute 'forward_metadata'. Did you mean: 'init_forward_metadata'?

and

  File "/sglang-checkout/python/sglang/srt/models/deepseek_v2.py", line 2097, in forward_absorb_fused_mla_rope_prepare
    attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot unpack non-iterable ForwardMetadata object

@Fridge003 (Collaborator)

@akhilg-nv That's unrelated to this PR

Fridge003 merged commit e607850 into sgl-project:main on Nov 4, 2025
60 of 73 checks passed