[FEAT] Support GGUF format #2215
Conversation
lm_head.weight is directly used in many places; however, vllm changes it to be
Thanks for the contributions. Can you fix the CI errors?
Force-pushed from 5c616a5 to 2cffa70
How do I trigger the CI?
Pass lm_head to LogitsProcessor and check the weight inside it.
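For illustration, a toy sketch of that idea, assuming a simple forward pass; the names and signature here are illustrative only, not sglang's actual LogitsProcessor API:

```python
import torch
from torch import nn


class TinyLM(nn.Module):
    """Toy model showing the pattern: hand the lm_head module (not its
    .weight tensor) to the logits processor, so the processor can decide
    how to read a tied or quantized weight. Illustrative names only."""

    def __init__(self, vocab_size=128, hidden_size=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def logits_processor(self, hidden_states, lm_head):
        # The processor inspects the module it was handed and pulls the
        # weight out itself (here: a plain matmul against lm_head.weight).
        return hidden_states @ lm_head.weight.t()

    def forward(self, input_ids):
        hidden_states = self.embed_tokens(input_ids)
        # Before: logits_processor(..., self.lm_head.weight, ...)
        # After:  pass the module and let the processor check the weight.
        return self.logits_processor(hidden_states, self.lm_head)


print(TinyLM()(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 128])
```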
@zhengy001 CI won't be triggered for you automatically because you are a first-time contributor. You can send a random typo fix PR and I can merge that for you so your future commits can trigger CI automatically.
@zhengy001 Can you fix the CI errors?
@merrymercy Sure, working on it.
#2269 adds you as a new contributor, so your future commits will trigger CI automatically.
@merrymercy :)
There won't be "lm_head.weight" if self.config.tie_word_embeddings is True
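A minimal, self-contained sketch of the guard this implies, assuming a simple load_weights loop (not sglang's actual loader code):

```python
import torch
from torch import nn


class TiedLM(nn.Module):
    """Sketch of the case being discussed: with tied embeddings, the
    checkpoint carries no separate "lm_head.weight", so the loader must
    not expect one."""

    def __init__(self, vocab_size=32, hidden_size=8, tie_word_embeddings=True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_word_embeddings:
            # Output projection reuses the input embedding matrix.
            self.lm_head.weight = self.embed_tokens.weight

    def load_weights(self, weights):
        params = dict(self.named_parameters())
        for name, tensor in weights:
            if name == "lm_head.weight" and name not in params:
                # Tied checkpoint: the tensor is absent or redundant; skip it.
                continue
            params[name].data.copy_(tensor)


# A tied checkpoint only ships the embedding weight.
TiedLM().load_weights({"embed_tokens.weight": torch.randn(32, 8)}.items())
```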
Compared the result with vllm's. Please suggest if there is a better way.
This reverts commit 883c955.
Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
Port the multi-CTA radix-based top-k kernel from flashinfer PR flashinfer-ai/flashinfer#2215 into sglang as a JIT-compiled kernel. This replaces the existing AOT single-CTA top-k implementation for NSA attention, providing better performance on long sequences (32K+) where the multi-CTA path activates.

Key changes:
- Add `python/sglang/jit_kernel/topk.py`: Python API exposing three JIT top-k variants (basic, page-table transform, ragged transform) with workspace management and lazy compilation via `cache_once`.
- Add `python/sglang/jit_kernel/csrc/elementwise/topk.cuh`: CUDA wrapper providing TVM FFI entry points that dispatch to the flashinfer adaptive top-k kernels (TopKDispatch, TopKPageTableTransformDispatch, TopKRaggedTransformDispatch).
- Add `python/sglang/jit_kernel/include/sgl_kernel/topk_fi.cuh`: Core CUDA implementation adapted from flashinfer, featuring:
  - 8-bit radix selection algorithm with multi-CTA support for large sequences (threshold configurable, default 32K)
  - Support for float32, float16, and bfloat16 input types
  - row_starts parameter for ragged input score layouts (sglang-specific)
  - Three output modes: indices-only, page-table lookup, and ragged offset addition
- Update `python/sglang/srt/layers/attention/nsa_backend.py`: Switch the NSA indexer to import from the JIT kernel instead of the AOT sgl_kernel.
- Update `sgl-kernel/python/sgl_kernel/top_k.py`: Add a JIT fallback path controlled by the SGLANG_USE_JIT_TOPK env var (default enabled). When JIT is available, fast_topk_v2 / fast_topk_transform_fused / fast_topk_transform_ragged_fused transparently delegate to the JIT kernels.
- Add `sgl-kernel/tests/test_topk_jit.py`: Correctness tests covering basic, page-table, ragged, and trivial (length <= topk) cases across various batch sizes and sequence lengths up to 131K.
- Add `sgl-kernel/benchmarks/bench_topk_jit.py`: Latency benchmark comparing the JIT multi-CTA and AOT single-CTA kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
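For context, a hedged usage sketch of what the JIT top-k computes. SGLANG_USE_JIT_TOPK and fast_topk_v2 are named in the commit message above, but the exact kernel signature and import path are assumptions, so the kernel call is left commented and a plain torch.topk reference stands in for the semantics:

```python
import os
import torch

os.environ.setdefault("SGLANG_USE_JIT_TOPK", "1")  # opt into the JIT path


def reference_topk(scores: torch.Tensor, lengths: torch.Tensor, k: int) -> torch.Tensor:
    """Per-row top-k indices over variable-length rows (padding masked out)."""
    masked = scores.clone()
    for row, length in enumerate(lengths.tolist()):
        masked[row, length:] = float("-inf")
    return masked.topk(k, dim=-1).indices


scores = torch.randn(4, 65536, dtype=torch.float32)   # padded score rows
lengths = torch.tensor([40000, 65536, 4096, 2048])    # true row lengths
ref_indices = reference_topk(scores, lengths, k=2048)

# from sglang.jit_kernel.topk import ...               # assumed import path
# out = fast_topk_v2(scores, lengths, topk=2048)       # assumed signature
```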
Motivation
#1616
Modifications
Support GGUF format
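For illustration only (not necessarily how this PR wires it up), a minimal sketch of enumerating tensors in a GGUF checkpoint with the `gguf` Python package; the file path is a placeholder:

```python
from gguf import GGUFReader  # pip install gguf

# Placeholder path; a real run needs a local GGUF checkpoint.
reader = GGUFReader("model.Q4_K_M.gguf")
for tensor in reader.tensors:
    # Each entry exposes the GGUF tensor name, quantization type, and shape,
    # which a loader maps onto the model's parameter names and dequantizes.
    print(tensor.name, tensor.tensor_type, tensor.shape)
```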
Checklist