feat: Add C++ native tree traversal for speculative decoding with xgrammar #15465

Open

Ubospica wants to merge 2 commits into sgl-project:main from Ubospica:main-dev/2025-12-18-cpp-mask-gen

Conversation

@Ubospica (Collaborator) commented Dec 19, 2025

This PR adds support for xgrammar's native C++ _traverse_draft_tree implementation to accelerate tree traversal during speculative decoding with constrained grammar.

It also upgrades xgrammar to v0.1.29, which supports the _traverse_draft_tree feature. See the docs: https://xgrammar.mlc.ai/docs/_modules/xgrammar/testing.html#_traverse_draft_tree.

Changes

  • Add automatic dispatch to the C++ native implementation when the xgrammar backend is available
  • Rename the original Python implementation to traverse_tree_fallback, keeping it as a fallback for other backends
  • Add a benchmark script comparing the Python fallback with the C++ native implementation
  • Add a unit test verifying that the native implementation matches the fallback's results
  • Add TestEagleConstrainedDecodingTopK to cover tree-structured scenarios (topk > 1)

Benchmark

```bash
bash benchmark/speculative/run_bench_traverse_tree.sh
```

Benchmark Results

| Config | Python Fallback | C++ Native | Speedup |
|---|---|---|---|
| steps=1, topk=4, tokens=4 | 0.147 ms | 0.035 ms | 4.17x |
| steps=5, topk=4, tokens=16 | 1.031 ms | 0.121 ms | 8.51x |
| steps=5, topk=4, tokens=64 | 4.254 ms | 0.389 ms | 10.92x |

All results match between Python and C++ implementations.

Summary: the C++ native implementation achieves a 4-11x speedup over the Python fallback, with larger trees seeing the greater gains.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a critical performance optimization for speculative decoding within constrained grammar by integrating a C++ native implementation for tree traversal. This change aims to drastically reduce processing times, especially for complex tree structures, while maintaining full compatibility and correctness through a well-defined fallback mechanism. The enhancements are validated by comprehensive benchmarks demonstrating significant speedups and thorough unit tests.

Highlights

  • Performance Improvement: Introduced a C++ native implementation for _traverse_draft_tree to significantly accelerate tree traversal during speculative decoding with constrained grammar, achieving 4-11x speedup over the Python fallback.
  • Automatic Dispatch: Implemented automatic dispatch to the C++ native implementation when the xgrammar backend is available, ensuring optimal performance without manual configuration.
  • Python Fallback: Renamed the original Python implementation to traverse_tree_fallback to serve as a robust fallback for other backends where the C++ native version is not applicable.
  • Benchmarking and Testing: Added a dedicated benchmark script to compare the performance of the Python fallback and C++ native implementations, along with unit tests to verify that the native implementation produces identical results to the fallback, ensuring correctness. New test cases were also added for topk > 1 scenarios.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a C++ native implementation for tree traversal to accelerate speculative decoding, showing impressive performance gains. The changes are well-structured, including a dispatcher, benchmarks, and tests. My review focuses on improving the benchmark's realism and maintainability, enhancing test code efficiency and consistency, and flagging a potential dependency risk with the use of an internal API from xgrammar.

Comment on lines +39 to +42:

```python
try:
    from xgrammar.testing import _traverse_draft_tree as _traverse_draft_tree_native
except ImportError:
    _traverse_draft_tree_native = None
```

Severity: high

The native C++ traversal function is being imported from xgrammar.testing. Using components from a library's testing module is risky, as these are typically considered internal, unstable APIs. They can be changed or removed in minor versions without notice, which could break this functionality in the future. It would be more robust to depend on a public, stable API from xgrammar. Please consider coordinating with the xgrammar maintainers to expose this function through a public interface.
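
One way to hedge against that risk (hypothetical, not part of this PR) is to gate the native path on both a minimum installed version and a successful import, so a release that moves the private API silently degrades to the fallback:

```python
# Hypothetical guard; the PR itself only catches ImportError.
from importlib.metadata import PackageNotFoundError, version


def native_traversal_available(min_version: str = "0.1.29") -> bool:
    """True only if xgrammar is installed, new enough, and exposes the API."""
    try:
        installed = version("xgrammar")
    except PackageNotFoundError:
        return False

    def key(v: str) -> tuple:
        # Tolerate suffixes like "0.1.29rc1" by keeping only the digits.
        parts = []
        for piece in v.split(".")[:3]:
            digits = "".join(ch for ch in piece if ch.isdigit())
            parts.append(int(digits) if digits else 0)
        return tuple(parts)

    if key(installed) < key(min_version):
        return False
    try:
        from xgrammar.testing import _traverse_draft_tree  # noqa: F401
    except ImportError:
        return False
    return True
```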

@Ubospica (Collaborator, Author) replied Dec 19, 2025:

This PR focuses on the speed enhancement first. On the xgrammar side, the _traverse_draft_tree API will be kept stable.

Comment on lines +55 to +76:

```python
# Build tree: node 0 -> topk children -> topk grandchildren...
idx = 1
level_nodes = [0]
draft_tokens[0] = token_ids[0]  # First token: "{"

for step in range(spec_steps):
    if idx >= spec_draft_tokens:
        break
    next_level = []
    for parent in level_nodes:
        if idx >= spec_draft_tokens:
            break
        retrieve_next_token[parent] = idx
        for k in range(spec_topk):
            if idx >= spec_draft_tokens:
                break
            next_level.append(idx)
            if k < spec_topk - 1 and idx + 1 < spec_draft_tokens:
                retrieve_next_sibling[idx] = idx + 1
            draft_tokens[idx] = token_ids[step + 1]
            idx += 1
    level_nodes = next_level
```

Severity: medium

The current implementation of create_tree generates a tree where all paths are identical prefixes of the same JSON string. While this is useful for measuring raw traversal speed, it doesn't fully exercise the grammar logic across different valid branches. To create a more realistic benchmark that better reflects real-world usage with constrained decoding, consider diversifying the tokens at each branch. For example, you could construct the tree such that sibling nodes represent different valid keys or values in the JSON structure.
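
The reviewer's suggestion might look like the sketch below (plain lists and made-up token ids for illustration; the real benchmark uses torch tensors): each sibling position draws from a different token sequence, so branches diverge instead of repeating the same prefix.

```python
# Illustrative tree with diversified sibling tokens; all ids are invented.
spec_topk = 2
spec_steps = 2
# One token sequence per sibling position, for depths 0..2.
branch_token_ids = [
    [10, 11, 12],  # tokens taken by sibling 0 at each depth
    [20, 21, 22],  # tokens taken by sibling 1 at each depth
]

num_nodes = 1 + spec_topk + spec_topk * spec_topk  # 7 nodes here
draft_tokens = [0] * num_nodes
draft_tokens[0] = branch_token_ids[0][0]  # root token

idx = 1
level_nodes = [0]
for step in range(spec_steps):
    next_level = []
    for parent in level_nodes:
        for k in range(spec_topk):
            # Siblings under the same parent now get different tokens.
            draft_tokens[idx] = branch_token_ids[k][step + 1]
            next_level.append(idx)
            idx += 1
    level_nodes = next_level
```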

Comment on lines +108 to +161:

```python
# Benchmark Python fallback
for _ in range(num_warmup):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)  # Accept first draft token before traversal
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    traverse_tree_fallback(
        retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, bitmask
    )

python_times = []
for _ in range(num_iter):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)  # Accept first draft token before traversal
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    start = time.perf_counter()
    traverse_tree_fallback(
        retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, bitmask
    )
    python_times.append((time.perf_counter() - start) * 1000)

python_avg = sum(python_times) / len(python_times)
python_bitmask = bitmask.clone()

# Benchmark C++ native
if _traverse_draft_tree_native is None:
    print("C++ native not available!")
    return

for _ in range(num_warmup):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    _traverse_draft_tree_native(
        retrieve_next_token,
        retrieve_next_sibling,
        draft_tokens,
        grammar.matcher,
        bitmask,
    )

native_times = []
for _ in range(num_iter):
    grammar = create_grammar(compiled)
    grammar.accept_token(first_token)
    bitmask = allocate_token_bitmask(spec_draft_tokens, vocab_size)
    start = time.perf_counter()
    _traverse_draft_tree_native(
        retrieve_next_token,
        retrieve_next_sibling,
        draft_tokens,
        grammar.matcher,
        bitmask,
    )
    native_times.append((time.perf_counter() - start) * 1000)
```

Severity: medium

There is significant code duplication between the benchmarking logic for the Python fallback and the C++ native implementation. The warmup and measurement loops are nearly identical. To improve maintainability and reduce redundancy, consider refactoring this logic into a helper function that takes the function to be benchmarked as an argument and handles the setup, warmup, and timing.
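
The suggested refactor could look roughly like this (a sketch; `bench` and `make_state` are hypothetical names, and the real script would pass the grammar, tensors, and bitmask through `make_state`):

```python
import time


def bench(fn, make_state, num_warmup=3, num_iter=10):
    """Warm up, then time fn over num_iter fresh setups; returns avg ms."""
    for _ in range(num_warmup):
        fn(*make_state())
    times_ms = []
    for _ in range(num_iter):
        args = make_state()  # fresh grammar + bitmask, excluded from timing
        start = time.perf_counter()
        fn(*args)
        times_ms.append((time.perf_counter() - start) * 1000)
    return sum(times_ms) / len(times_ms)

# Usage with the benchmark's real functions might then be:
#   python_avg = bench(traverse_tree_fallback, make_fallback_state)
#   native_avg = bench(native_call, make_native_state)
```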

Comment on lines +363 to +370:

```python
def create_grammar():
    grammar = xgr.Grammar.builtin_json_grammar()
    tokenizer_info = xgr.TokenizerInfo(
        vocab, vocab_size=len(vocab), stop_token_ids=[]
    )
    compiled = xgr.GrammarCompiler(tokenizer_info).compile_grammar(grammar)
    matcher = xgr.GrammarMatcher(compiled)
    return XGrammarGrammar(matcher, len(vocab), None, None)
```

Severity: medium

The create_grammar helper function recompiles the grammar from scratch on every call. As noted in the PR description and benchmark code, grammar compilation is an expensive operation. To make the test more efficient, you could compile the grammar once within the test method and then have create_grammar reuse the compiled object to create new GrammarMatcher instances for each run.
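
The pattern being asked for, stripped of xgrammar specifics (the counter and dict stand-ins below are purely illustrative): pay the compile cost once at module or test-class scope, and have create_grammar only build the cheap per-run object.

```python
# Toy stand-ins: compile_grammar_once() plays the role of the expensive
# xgr.GrammarCompiler(...).compile_grammar(...) step.
compile_calls = 0


def compile_grammar_once():
    global compile_calls
    compile_calls += 1
    return {"compiled": True}


COMPILED = compile_grammar_once()  # done once, e.g. in setUpClass


def create_grammar():
    # Cheap per-test step, analogous to xgr.GrammarMatcher(COMPILED).
    return {"matcher_of": COMPILED}


g1, g2 = create_grammar(), create_grammar()
```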

Comment on lines +372 to +374:

```python
retrieve_next_token = torch.tensor([1, 2, -1], dtype=torch.int32)
retrieve_next_sibling = torch.tensor([-1, -1, -1], dtype=torch.int32)
draft_tokens = torch.tensor([3, 6, 4], dtype=torch.int32)
```

Severity: medium

The tensors retrieve_next_token, retrieve_next_sibling, and draft_tokens are created with dtype=torch.int32. However, the corresponding benchmark script (bench_traverse_tree.py) uses dtype=torch.int64. For consistency and to prevent potential issues with the native C++ implementation which might expect a specific integer size, it's recommended to use torch.int64 in this test as well.

Suggested change:

```diff
-retrieve_next_token = torch.tensor([1, 2, -1], dtype=torch.int32)
-retrieve_next_sibling = torch.tensor([-1, -1, -1], dtype=torch.int32)
-draft_tokens = torch.tensor([3, 6, 4], dtype=torch.int32)
+retrieve_next_token = torch.tensor([1, 2, -1], dtype=torch.int64)
+retrieve_next_sibling = torch.tensor([-1, -1, -1], dtype=torch.int64)
+draft_tokens = torch.tensor([3, 6, 4], dtype=torch.int64)
```

@github-actions (Bot) added the dependencies label (Pull requests that update a dependency file) on Dec 19, 2025
@Ubospica (Collaborator, Author) commented:
cc @hnyls2002 @merrymercy

@Ubospica Ubospica changed the title from "Add C++ native tree traversal for speculative decoding with xgrammar" to "feat: Add C++ native tree traversal for speculative decoding with xgrammar" on Dec 19, 2025
@Ubospica Ubospica force-pushed the main-dev/2025-12-18-cpp-mask-gen branch from fa4a23f to 6f14296 on December 19, 2025 at 08:26
@Ubospica Ubospica force-pushed the main-dev/2025-12-18-cpp-mask-gen branch from 0289cc5 to b76b80d on January 1, 2026 at 06:36
Signed-off-by: Ubospica <ubospica@gmail.com>
@Ubospica Ubospica force-pushed the main-dev/2025-12-18-cpp-mask-gen branch from b76b80d to 29e9b3c on January 1, 2026 at 06:36
@Ubospica (Collaborator, Author) commented Jan 1, 2026:

cc @hnyls2002 I have updated it to the latest main. Do you know how I can get the CI fixed?

Signed-off-by: Ubospica <ubospica@gmail.com>
@hanming-lu (Collaborator):

/tag-and-rerun-ci


Labels: dependencies, run-ci
