[fusion] add composable fusion pass framework #10549
DevashishLal-CB wants to merge 21 commits into sgl-project:main
Conversation
Things Pending as of now
Can we add a sgl-kernel fused kernel pass example? Such as …
@BBuf Added the example for topk_softmax fusion, and also added the rmsnorm_quant fusion pass with tests. This MR is ready for review; I'll look into CUDA graph support and do it as a separate MR. Will collaborate with @yuan-luo.
Cool, we'll review ASAP.
```python
from sglang.srt.server_args import ServerArgs


class FusionManager(CustomGraphPass):
```
Instead of a FusionManager, we would prefer to add an abstraction and form a PassManager, in which fusion is just one of several pass types, similar to the LLVM pass concept. Other pass types could include AsyncTPPass, AllReduceFusionPass, RMSNormQuantFusionPass, etc.
Refer to https://github.com/sgl-project/sglang/pull/10987/files#diff-61475915ef47a86d47da62c647cd346f64c4b702c94728ab84172aed428e4fc0
for more details.
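For illustration, a minimal sketch of that shape might look like the following; the class names and interface here are assumptions for discussion, not the linked PR's exact API.

```python
# Minimal sketch of a PassManager abstraction (illustrative, not the linked
# PR's exact API): every pass rewrites an FX graph, and fusion passes are
# just one category of pass registered with the manager.
from abc import ABC, abstractmethod
from typing import List

import torch.fx as fx


class InductorPass(ABC):
    @abstractmethod
    def __call__(self, graph: fx.Graph) -> None:
        """Rewrite the graph in place."""


class FusionPass(InductorPass):
    """Base class for pattern-match-and-replace fusion passes."""


class RMSNormQuantFusionPass(FusionPass):
    def __call__(self, graph: fx.Graph) -> None:
        ...  # match rmsnorm + quant subgraphs and swap in the fused op


class PassManager:
    """Runs a configurable sequence of passes over the compiled graph."""

    def __init__(self) -> None:
        self._passes: List[InductorPass] = []

    def add_pass(self, p: InductorPass) -> None:
        self._passes.append(p)

    def __call__(self, graph: fx.Graph) -> None:
        for p in self._passes:
            p(graph)
```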
```python
from sglang.srt.server_args import ServerArgs

try:
    from vllm import _custom_ops  # noqa: F401
```
I'll port over the kernel
```diff
@@ -147,14 +156,21 @@ def patch_model(
         tp_group.ca_comm = backup_ca_comm


-def set_torch_compile_config():
+def set_torch_compile_config(server_args, model_config):
```
Parameters in the def should have type annotations.
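For example, something along these lines; the `ModelConfig` import path is an assumption:

```python
# Possible typed signature; ModelConfig's import path is an assumption here.
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs


def set_torch_compile_config(server_args: ServerArgs, model_config: ModelConfig) -> None:
    ...
```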
```diff
@@ -1788,6 +1788,8 @@ def init_device_graphs(self):
             return

+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            if self.server_args.enable_torch_compile:
```
Do we need to run torch_compile in the disable_cuda_graph case?
I haven't looked into it much, but the two passes I added weren't working with CUDA graphs enabled. I'm also not sure whether all other hardware platforms support CUDA graphs.
```python
# limitations under the License.
# ==============================================================================

import logging
```
We'd better put this configuration file in the python/sglang/srt/configs/ directory.
```python
return torch.compile(
    torch.no_grad()(forward),
    mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"),
    dynamic=False,
)
```
You have to use fullgraph=True. That's a merge blocker, isn't it?
Currently Dynamo encounters graph breaks on attention; a unified attention op would solve this, as done in #10062.
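Once the graph breaks are resolved, the compile call could presumably be tightened to something like the sketch below (keeping the existing env-var knob; this is not the PR's current code, and the helper name is illustrative):

```python
# Sketch only: compile the whole forward with fullgraph=True so any remaining
# graph break fails loudly instead of silently splitting the graph.
import os

import torch


def _compile_forward(forward):
    return torch.compile(
        torch.no_grad()(forward),
        mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"),
        fullgraph=True,
        dynamic=False,
    )
```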
```diff
@@ -114,6 +114,21 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
         _to_torch(sub, reverse, num_tokens)


+def _torch_compile_wrapper(forward):
```
No more design patterns in 2025 except Wrapper and Manager, right? [sarcasm]
Your function is a decorator, not a wrapper.
Yeah, this entry point is supposed to be a placeholder. Once we have a custom backend (which will be required for piecewise CUDA graphs), that backend would manage this invocation; I didn't want to make a big diff here.
```python
from sglang.srt.compilation.fusion.fusion_pass import FusionPass


class RMSNormQuantPass(FusionPass):
```
It's not clear from the name and namespace what type of quantization is supported: fp8, int8, int4, or binary?
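Per the PR description, the target is FP8 (e4m3); one way to make that explicit is to encode the dtype in the pass name, e.g. (illustrative naming only, not the PR's actual class):

```python
# Illustrative naming only -- not the PR's actual class. Encoding the target
# dtype in the name answers the "which quantization?" question at a glance.
from sglang.srt.compilation.fusion.fusion_pass import FusionPass


class RMSNormFP8QuantPass(FusionPass):
    """Fuses RMSNorm + FP8 (e4m3) quantization into a single kernel call."""
```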
Signed-off-by: Devashish Lal <devcode@fb.com>
…#2243)

## 📌 Description

FP8 model inference requires multiple intermediate quantization kernels, which can be avoided by fusing the norm and quantization kernels. Consumers like sglang and vllm can lower to these norm + quant fusion kernels using custom torch compile passes.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

### Reference

I have been working on adding custom fusion passes to sglang as part of the following [RFC](sgl-project/sglang#10118) and would like to use flashinfer's norm kernels for the norm + quant fusions instead of migrating vllm kernels to sglang as part of the following [MR](sgl-project/sglang#10549).

### Implementation

I realise that the existing kernels (at least for rmsnorm) could be modified to take the scale as an optional parameter, thereby avoiding most code duplication. However, as an initial implementation, I have opted for a separate implementation route. This can be refactored if required.

For fused_add_rmsnorm_quant, I don't think an in-place update is possible since the dtypes of the input and output differ.

Currently, the FP8 E4M3 numeric limit (448) is hard-coded, as I am not aware of a way to get this value at compile time without including c10 headers from torch, and I am not sure whether that is acceptable post tvm-ffi migration. The following is a snippet from vLLM; I have seen similar code elsewhere for getting the FP8 numeric limits:

```cpp
#include <c10/util/Float8_e4m3fn.h>

template <typename T,
          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
                                      std::is_same_v<T, int8_t>>>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};
```

The best option in my mind is to introduce `include/flashinfer/fp8.h` containing something similar to the above snippet, and to also support e5m2.

### Tests

atol and rtol for the fp8 assertions had to be high due to the low-precision nature of the data, but with tolerances of 1e-2 only a few tests fail, each with a single-element mismatch.

Signed-off-by: Devashish Lal <laldevashish@gmail.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
These kernels are faster across all benchmarks when compared against AOT sglang, fused flashinfer (CuTe DSL), and the unfused implementation.
Signed-off-by: Devashish Lal <devcode@fb.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
Signed-off-by: Devashish Lal <devcode@fb.com>
Motivation
Initial implementation of the changes proposed in #10118
Modifications
This PR adds the fusion passes and integration tests for them.
Passes added:
For the fusion passes to work with the cuda graph runner I had to get rid of the model patching (alternatively, I could rewrite the passes with pattern functions that match pure PyTorch code; a rough sketch of that approach is below). We should avoid this model patching in any case, as it interferes with the compilation process.
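For reference, a rough sketch of what a pure-PyTorch pattern function for the rmsnorm + FP8 quant case could look like, using torch._inductor's pattern matcher. This is not the PR's code: exact pattern-matcher APIs vary across PyTorch versions, and `torch.ops.sgl_kernel.rmsnorm_quant_fp8` is a hypothetical placeholder, not a real registered op.

```python
# Rough sketch (not the PR's code): express the unfused pattern in plain
# PyTorch ops and register a replacement that calls a fused kernel.
# `torch.ops.sgl_kernel.rmsnorm_quant_fp8` is a hypothetical placeholder.
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

rmsnorm_quant_patterns = PatternMatcherPass()


def rmsnorm_quant_pattern(x, weight, scale):
    # Unfused reference: RMSNorm followed by per-tensor FP8 quantization.
    var = x.pow(2).mean(dim=-1, keepdim=True)
    normed = x * torch.rsqrt(var + 1e-6) * weight
    return (normed / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)


def rmsnorm_quant_replacement(x, weight, scale):
    return torch.ops.sgl_kernel.rmsnorm_quant_fp8(x, weight, scale)  # placeholder


def register_rmsnorm_quant_fusion():
    example_inputs = [
        torch.randn(4, 128, device="cuda"),
        torch.randn(128, device="cuda"),
        torch.tensor(1.0, device="cuda"),
    ]
    register_replacement(
        rmsnorm_quant_pattern,
        rmsnorm_quant_replacement,
        example_inputs,
        fwd_only,
        rmsnorm_quant_patterns,
    )
```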
I have also added model_bench.py. The idea is to provide a stripped-down sglang runtime where each layer can be instantiated in isolation, helping write integration and accuracy tests for fusion passes and fused kernels (see the test sketch below).
Accuracy Tests
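As an illustration of the kind of accuracy test this enables (the helper names below are hypothetical, not model_bench.py's real API):

```python
# Hypothetical shape of an isolated-layer accuracy test built on
# model_bench.py; build_layer/compile_with_fusion_passes are assumed helpers.
import torch


def test_rmsnorm_quant_fusion_accuracy(build_layer, compile_with_fusion_passes):
    layer = build_layer("rmsnorm")  # instantiate a single layer in isolation
    x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")

    ref = layer(x)  # eager, unfused reference
    out = compile_with_fusion_passes(layer)(x)  # torch.compile + fusion passes

    torch.testing.assert_close(out.float(), ref.float(), rtol=1e-2, atol=1e-2)
```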
Benchmarking and Profiling
For Llama 3.1 8B FP8, BS 1, ISL 1024, OSL 1024: 6.2% gains.
Logs
MM + SiLU-and-Mul fusion
MM + SiLU-and-Mul + Quant fusion (I have a small diff to use sgl_per_tensor_quant_fp8 for the quant instead of the Triton quant kernel; I will add support for the default quant kernel before merge)
Checklist