feat: Add C++ native tree traversal for speculative decoding with xgrammar #15465

Ubospica wants to merge 2 commits into sgl-project:main
Conversation
Summary of Changes

Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a critical performance optimization for speculative decoding within constrained grammar by integrating a C++ native implementation for tree traversal. This change aims to drastically reduce processing times, especially for complex tree structures, while maintaining full compatibility and correctness through a well-defined fallback mechanism. The enhancements are validated by comprehensive benchmarks demonstrating significant speedups and thorough unit tests.
Code Review
This pull request introduces a C++ native implementation for tree traversal to accelerate speculative decoding, showing impressive performance gains. The changes are well-structured, including a dispatcher, benchmarks, and tests. My review focuses on improving the benchmark's realism and maintainability, enhancing test code efficiency and consistency, and flagging a potential dependency risk with the use of an internal API from xgrammar.
```python
try:
    from xgrammar.testing import _traverse_draft_tree as _traverse_draft_tree_native
except ImportError:
    _traverse_draft_tree_native = None
```
The native C++ traversal function is being imported from xgrammar.testing. Using components from a library's testing module is risky, as these are typically considered internal, unstable APIs. They can be changed or removed in minor versions without notice, which could break this functionality in the future. It would be more robust to depend on a public, stable API from xgrammar. Please consider coordinating with the xgrammar maintainers to expose this function through a public interface.
This PR focuses on the speed enhancement first. On the xgrammar side, the _traverse_draft_tree API will be kept stable.
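For context, here is a minimal sketch of the guarded dispatch that the try/except import above enables. The wrapper name `traverse_tree` is hypothetical; the argument layouts follow the benchmark snippets later in this thread:

```python
def traverse_tree(retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, bitmask):
    # Hypothetical dispatcher: prefer xgrammar's C++ path when the import
    # above succeeded, otherwise fall back to the pure-Python traversal.
    if _traverse_draft_tree_native is not None:
        _traverse_draft_tree_native(
            retrieve_next_token,
            retrieve_next_sibling,
            draft_tokens,
            grammar.matcher,  # the native API takes the underlying matcher
            bitmask,
        )
    else:
        traverse_tree_fallback(
            retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, bitmask
        )
```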
```python
# Build tree: node 0 -> topk children -> topk grandchildren...
idx = 1
level_nodes = [0]
draft_tokens[0] = token_ids[0]  # First token: "{"

for step in range(spec_steps):
    if idx >= spec_draft_tokens:
        break
    next_level = []
    for parent in level_nodes:
        if idx >= spec_draft_tokens:
            break
        retrieve_next_token[parent] = idx
        for k in range(spec_topk):
            if idx >= spec_draft_tokens:
                break
            next_level.append(idx)
            if k < spec_topk - 1 and idx + 1 < spec_draft_tokens:
                retrieve_next_sibling[idx] = idx + 1
            draft_tokens[idx] = token_ids[step + 1]
            idx += 1
    level_nodes = next_level
```
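As the snippets here show, the draft tree is encoded with two flat arrays: `retrieve_next_token[i]` points to node i's first child and `retrieve_next_sibling[i]` to its next sibling, with -1 meaning none. A minimal illustration, matching the three-node chain used in the unit test further down:

```python
# Chain 0 -> 1 -> 2: each node has exactly one child and no siblings.
retrieve_next_token = [1, 2, -1]      # node 0's first child is 1; node 1's is 2; node 2 is a leaf
retrieve_next_sibling = [-1, -1, -1]  # no node has a next sibling
```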
The current implementation of create_tree generates a tree where all paths are identical prefixes of the same JSON string. While this is useful for measuring raw traversal speed, it doesn't fully exercise the grammar logic across different valid branches. To create a more realistic benchmark that better reflects real-world usage with constrained decoding, consider diversifying the tokens at each branch. For example, you could construct the tree such that sibling nodes represent different valid keys or values in the JSON structure.
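One way to act on this, sketched as a drop-in replacement for the inner loop above. `candidate_tokens_at` is a hypothetical helper returning, say, token ids for different valid JSON keys or values at a given depth; everything else is taken from the snippet:

```python
for k in range(spec_topk):
    if idx >= spec_draft_tokens:
        break
    next_level.append(idx)
    if k < spec_topk - 1 and idx + 1 < spec_draft_tokens:
        retrieve_next_sibling[idx] = idx + 1
    # Instead of giving every sibling the same token_ids[step + 1],
    # assign each sibling a different grammar-valid continuation.
    draft_tokens[idx] = candidate_tokens_at(step)[k]
    idx += 1
```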
```python
# Benchmark Python fallback
for _ in range(num_warmup):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)  # Accept first draft token before traversal
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    traverse_tree_fallback(
        retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, bitmask
    )

python_times = []
for _ in range(num_iter):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)  # Accept first draft token before traversal
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    start = time.perf_counter()
    traverse_tree_fallback(
        retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, bitmask
    )
    python_times.append((time.perf_counter() - start) * 1000)

python_avg = sum(python_times) / len(python_times)
python_bitmask = bitmask.clone()

# Benchmark C++ native
if _traverse_draft_tree_native is None:
    print("C++ native not available!")
    return

for _ in range(num_warmup):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    _traverse_draft_tree_native(
        retrieve_next_token,
        retrieve_next_sibling,
        draft_tokens,
        grammar.matcher,
        bitmask,
    )

native_times = []
for _ in range(num_iter):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    start = time.perf_counter()
    _traverse_draft_tree_native(
        retrieve_next_token,
        retrieve_next_sibling,
        draft_tokens,
        grammar.matcher,
        bitmask,
    )
    native_times.append((time.perf_counter() - start) * 1000)
```
There is significant code duplication between the benchmarking logic for the Python fallback and the C++ native implementation. The warmup and measurement loops are nearly identical. To improve maintainability and reduce redundancy, consider refactoring this logic into a helper function that takes the function to be benchmarked as an argument and handles the setup, warmup, and timing.
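A possible shape for that helper, sketched from the duplicated loops above. All names are taken from the benchmark snippet and assumed to be in scope; `bench_traversal` itself is hypothetical:

```python
def bench_traversal(traverse_fn, num_warmup, num_iter):
    # Shared setup, warmup, and timing for both implementations. traverse_fn
    # receives the freshly created grammar and bitmask and runs one traversal.
    def run_once():
        grammar = create_grammar(compiled)
        grammar.accept_token(first_token)  # Accept first draft token before traversal
        bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
        start = time.perf_counter()
        traverse_fn(grammar, bitmask)
        return (time.perf_counter() - start) * 1000, bitmask

    for _ in range(num_warmup):
        run_once()
    results = [run_once() for _ in range(num_iter)]
    times = [ms for ms, _ in results]
    return sum(times) / len(times), results[-1][1]  # avg ms, last bitmask

# Usage for the Python fallback; the native path only swaps the lambda.
python_avg, python_bitmask = bench_traversal(
    lambda g, m: traverse_tree_fallback(
        retrieve_next_token, retrieve_next_sibling, draft_tokens, g, m
    ),
    num_warmup,
    num_iter,
)
```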
```python
def create_grammar():
    grammar = xgr.Grammar.builtin_json_grammar()
    tokenizer_info = xgr.TokenizerInfo(
        vocab, vocab_size=len(vocab), stop_token_ids=[]
    )
    compiled = xgr.GrammarCompiler(tokenizer_info).compile_grammar(grammar)
    matcher = xgr.GrammarMatcher(compiled)
    return XGrammarGrammar(matcher, len(vocab), None, None)
```
The create_grammar helper function recompiles the grammar from scratch on every call. As noted in the PR description and benchmark code, grammar compilation is an expensive operation. To make the test more efficient, you could compile the grammar once within the test method and then have create_grammar reuse the compiled object to create new GrammarMatcher instances for each run.
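A sketch of that refactor, reusing the names from the test helper above: compile once at setup, then construct only a fresh GrammarMatcher per call.

```python
# Compile once; compilation is the expensive step.
grammar = xgr.Grammar.builtin_json_grammar()
tokenizer_info = xgr.TokenizerInfo(vocab, vocab_size=len(vocab), stop_token_ids=[])
compiled = xgr.GrammarCompiler(tokenizer_info).compile_grammar(grammar)

def create_grammar():
    # Each call only pays for a new matcher over the shared compiled grammar.
    matcher = xgr.GrammarMatcher(compiled)
    return XGrammarGrammar(matcher, len(vocab), None, None)
```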
```python
retrieve_next_token = torch.tensor([1, 2, -1], dtype=torch.int32)
retrieve_next_sibling = torch.tensor([-1, -1, -1], dtype=torch.int32)
draft_tokens = torch.tensor([3, 6, 4], dtype=torch.int32)
```
The tensors retrieve_next_token, retrieve_next_sibling, and draft_tokens are created with dtype=torch.int32. However, the corresponding benchmark script (bench_traverse_tree.py) uses dtype=torch.int64. For consistency and to prevent potential issues with the native C++ implementation which might expect a specific integer size, it's recommended to use torch.int64 in this test as well.
Suggested change:

```diff
-retrieve_next_token = torch.tensor([1, 2, -1], dtype=torch.int32)
-retrieve_next_sibling = torch.tensor([-1, -1, -1], dtype=torch.int32)
-draft_tokens = torch.tensor([3, 6, 4], dtype=torch.int32)
+retrieve_next_token = torch.tensor([1, 2, -1], dtype=torch.int64)
+retrieve_next_sibling = torch.tensor([-1, -1, -1], dtype=torch.int64)
+draft_tokens = torch.tensor([3, 6, 4], dtype=torch.int64)
```
Force-pushed: fa4a23f to 6f14296
Force-pushed: 0289cc5 to b76b80d
Force-pushed: b76b80d to 29e9b3c
cc @hnyls2002 I have updated it to the latest main. Do you know how I can get the CI fixed?
Signed-off-by: Ubospica <ubospica@gmail.com>
/tag-and-rerun-ci
This PR adds support for xgrammar's native C++ `_traverse_draft_tree` implementation to accelerate tree traversal during speculative decoding with constrained grammar. It also upgrades xgrammar to v0.1.29, which supports the `_traverse_draft_tree` feature. See docs: https://xgrammar.mlc.ai/docs/_modules/xgrammar/testing.html#_traverse_draft_tree

Changes

- Keep `traverse_tree_fallback` as a fallback for other backends
- Add `TestEagleConstrainedDecodingTopK` to test tree structure scenarios (topk > 1)
Benchmark Results
All results match between the Python and C++ implementations.
Summary: the C++ native implementation achieves a 4-11x speedup over the Python fallback, with higher speedups on larger trees.
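The comparison itself is not shown in the excerpt, but the `python_bitmask = bitmask.clone()` line in the benchmark suggests a check along these lines (a sketch, not the PR's actual code):

```python
# Verify the C++ and Python traversals produce identical token bitmasks.
assert torch.equal(python_bitmask, bitmask), "Python and C++ bitmasks diverge"
```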