Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical crash in the Kimi K2.5 model when using distributed attention and speculative decoding. The core fix refines the multimodal input embedding logic to ensure correct tensor concatenation and handling across forward modes. Additionally, a dedicated test case has been added to validate the stability and functionality of the Kimi K2.5 model in this previously problematic configuration.
/tag-and-rerun-ci
Code Review
This pull request introduces multimodal input handling for the Llama Eagle3 model, specifically for `forward_batch.mm_input_embeds` in extend mode, and updates test configurations for the Kimi K2.5 model to include distributed attention. The review highlights a critical issue in the Llama Eagle3 model's forward method, where an `assert embeds is not None` could cause a runtime error and is redundant given the subsequent null check. It also suggests removing commented-out test configurations in `test_kimi_k25.py` to improve code clarity.
```python
    and forward_batch.contains_mm_inputs()
    and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
):
    assert embeds is not None
```
The `assert embeds is not None` on this line is problematic. If `forward_batch.mm_input_embeds` (assigned on line 153) is `None` and the conditions on lines 154-157 are met, this assert will raise at runtime. The subsequent `if embeds is None:` on line 163 already handles the case where `embeds` might be `None`, making this assert redundant and potentially dangerous.
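One way to act on this, sketched under the assumption that the surrounding control flow matches the diff above (not a verbatim patch): fold the `None` check into the condition instead of asserting, so the existing fallback below still applies.

```python
embeds = forward_batch.mm_input_embeds
if (
    embeds is not None  # replaces the assert; skip the mm path when absent
    and forward_batch.forward_mode.is_extend()
    and forward_batch.contains_mm_inputs()
    and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
):
    embeds = torch.cat(
        [embeds[:-1], self.embed_tokens(input_ids[-1].unsqueeze(0))]
    )
if embeds is None:
    embeds = self.embed_tokens(input_ids)
```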
Code Review
This pull request modifies the `llama_eagle3.py` model to improve how embeddings are handled, specifically for multimodal inputs during extend. It also updates `test_kimi_k25.py` to introduce a new test variant for the TP8+DP8+MTP configuration, while commenting out older variants. Feedback suggests refactoring the nested embedding logic for readability and maintainability, and removing the commented-out test variants to keep the code clean.
```python
embeds = forward_batch.mm_input_embeds
if (
    forward_batch.forward_mode.is_extend()
    and forward_batch.contains_mm_inputs()
    and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
):
    assert embeds is not None
    embeds = torch.cat(
        [embeds[:-1], self.embed_tokens(input_ids[-1].unsqueeze(0))]
    )
if embeds is None:
    embeds = self.embed_tokens(input_ids)
```
The logic for determining the embeddings is a bit nested and could be hard to follow. Refactoring it into an if/elif/else structure would make the different cases clearer and improve readability and maintainability. I've also added a descriptive message to the assertion for better debugging.
Suggested change:

```diff
-embeds = forward_batch.mm_input_embeds
-if (
-    forward_batch.forward_mode.is_extend()
-    and forward_batch.contains_mm_inputs()
-    and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
-):
-    assert embeds is not None
-    embeds = torch.cat(
-        [embeds[:-1], self.embed_tokens(input_ids[-1].unsqueeze(0))]
-    )
-if embeds is None:
-    embeds = self.embed_tokens(input_ids)
+if (
+    forward_batch.forward_mode.is_extend()
+    and forward_batch.contains_mm_inputs()
+    and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
+):
+    embeds = forward_batch.mm_input_embeds
+    assert embeds is not None, "mm_input_embeds must be provided for extend with mm_inputs"
+    embeds = torch.cat(
+        [embeds[:-1], self.embed_tokens(input_ids[-1].unsqueeze(0))]
+    )
+elif forward_batch.mm_input_embeds is not None:
+    embeds = forward_batch.mm_input_embeds
+else:
+    embeds = self.embed_tokens(input_ids)
```
cc @yhyang201 |
Motivation
Closes #21336
The issue: when there are multimodal inputs, the draft model's embedding path cannot handle them (it runs into an out-of-bounds error).
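For intuition, a minimal self-contained sketch of the failure mode. This is my reading of the out-of-bounds issue (placeholder ids outside the text vocabulary), not code from the PR, and the placeholder id below is hypothetical:

```python
import torch
import torch.nn as nn

vocab_size, hidden_size = 32000, 64
embed_tokens = nn.Embedding(vocab_size, hidden_size)

# Hypothetical placeholder id for an image token, outside the text vocab.
IMAGE_PLACEHOLDER_ID = 32010
input_ids = torch.tensor([1, 15043, IMAGE_PLACEHOLDER_ID, 29889])

try:
    embed_tokens(input_ids)  # draft path embedding raw ids directly
except IndexError as err:
    print(f"out-of-bounds embedding lookup: {err}")
```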
Modifications
Follows the same pattern as in https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen3_5_mtp.py#L129-L131 to fix the multimodal input embedding.
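A toy version of that pattern, assuming the names used in the diff above (`mm_input_embeds` carries embeddings precomputed by the multimodal pipeline); the function name and shapes are illustrative, not from the repo:

```python
from typing import Optional

import torch
import torch.nn as nn

def draft_input_embeds(
    input_ids: torch.Tensor,
    mm_input_embeds: Optional[torch.Tensor],
    embed_tokens: nn.Embedding,
) -> torch.Tensor:
    if mm_input_embeds is not None:
        # Keep the precomputed rows; only the last (newly sampled) token
        # needs a fresh lookup in the draft embedding table.
        return torch.cat(
            [mm_input_embeds[:-1], embed_tokens(input_ids[-1].unsqueeze(0))]
        )
    return embed_tokens(input_ids)

embed_tokens = nn.Embedding(32000, 64)
input_ids = torch.tensor([1, 15043, 29889])
mm_embeds = torch.randn(3, 64)  # stand-in for forward_batch.mm_input_embeds
print(draft_input_embeds(input_ids, mm_embeds, embed_tokens).shape)  # (3, 64)
```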
Accuracy Tests
GSM8K accuracy 0.936; acceptance length consistently > 2.5 with the config step=3, topk=1, draft_num_token=4.
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci