
Fix autotuner crash when input tensor is None #2756

Merged
samuellees merged 5 commits into flashinfer-ai:main from he-yufeng:fix/autotuner-none-tensor
Mar 30, 2026

Conversation

@he-yufeng
Contributor

@he-yufeng he-yufeng commented Mar 11, 2026

Fixes #2749.

trtllm_fp8_block_scale_routed_moe passes routing_logits=None for non-routed calls, but _prepare_input_tensors assumes all inputs are tensors and crashes in _create_tensor_like trying to access .dtype on None.

Fix: skip None inputs and preserve them as-is. This matches the existing pattern in _prepare_input_tensors_with_batches which already handles non-tensor inputs with isinstance(t, torch.Tensor) checks.
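A minimal sketch of the guard, assuming the loop shape inside _prepare_input_tensors (prepare_inputs is an illustrative stand-in, not the real helper, which also consults the tuning profile):

    import torch

    def prepare_inputs(inputs):
        tensors = []
        for inp in inputs:
            if inp is None:
                # Preserve optional inputs (e.g. routing_logits in
                # non-routed MoE) instead of reading .dtype on None.
                tensors.append(None)
            else:
                # The real code builds a tuning tensor here, e.g. via
                # _create_tensor_like; cloning stands in for that step.
                tensors.append(inp.clone())
        return tensors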

Summary by CodeRabbit

  • Bug Fixes

    • Preserve missing inputs during input preparation so None entries are retained and not treated as tensors, preventing errors when some inputs are absent.
    • Relax and align routing and token-count validations to allow empty routing data when appropriate and ensure checks use actual token counts for consistency.
  • Tests

    • Add regression tests verifying None input handling and safe fallback behavior when routing/tuning data is absent.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 11, 2026

📝 Walkthrough

Preserve None-valued optional inputs in autotuner input preparation and relax fused MoE routing validations by deriving token counts from hidden_states and allowing absent or empty routing-related tensors.

Changes

  • Autotuner input handling (flashinfer/autotuner.py): _prepare_input_tensors now preserves None entries in inputs (appending None when encountered) and only calls _create_tensor_like when the input is non-None and the profile uses DynamicDim, preventing AttributeError on optional inputs like routing_logits.
  • Fused MoE routing & validation (flashinfer/fused_moe/core.py): derives num_tokens from hidden_states.shape[0] instead of routing_logits.shape[0]; makes shape/assertion checks for routing_logits, topk_ids, and expert_weights conditional so absent or empty routing tensors are accepted.
  • Tests, regression (tests/autotuner/test_autotuner_core.py): adds tests validating that _prepare_input_tensors handles None entries and that choose_one(...) in the no-tuning/inference path tolerates None inputs and returns the provided runner with tactic -1 when applicable.

Sequence Diagram(s)

(Skipped — changes are bug fixes and validation relaxations that do not introduce a new multi-component control flow requiring visualization.)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • yzh119
  • cyx-6
  • bkryu
  • jimmyzho
  • nv-yunzheq

Poem

🐇 I nudged a None beneath a log,

no crash, just hush beneath the bog,
counted tokens from hidden light,
routing gaps tucked out of sight,
autotune hops home through the fog.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): the title accurately and concisely describes the main fix, handling None input tensors in the autotuner to prevent crashes.
  • Description check (✅ Passed): the PR description clearly explains the issue, the root cause, and the fix, matching the template structure with a linked issue reference.
  • Linked Issues check (✅ Passed): the code changes fully address issue #2749: the autotuner now skips None inputs [autotuner.py], MoERunner derives num_tokens from hidden_states instead of routing_logits [fused_moe/core.py], and regression tests validate both code paths [test_autotuner_core.py].
  • Out of Scope Changes check (✅ Passed): all changes are directly scoped to fixing the None tensor handling issue: autotuner input preparation, MoERunner shape assertions, and corresponding test coverage.



Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 792a36d6-79c6-4bd0-944d-642934e658c4

📥 Commits

Reviewing files that changed from the base of the PR and between fe06b91 and af6cb712bf09e609ae590e65672a290eff2b2e42.

📒 Files selected for processing (1)
  • flashinfer/autotuner.py

Comment thread flashinfer/autotuner.py
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the autotuner's input preparation logic that caused a crash when optional tensor inputs were None. By introducing a check to gracefully handle and preserve None values, the change enhances the robustness of the autotuner, preventing failures in scenarios where certain inputs are intentionally omitted.

Highlights

  • Bug Fix: Resolved a crash in the autotuner's _prepare_input_tensors function that occurred when None was passed as an input tensor, specifically for optional tensors like routing_logits in non-routed MoE calls.
  • Input Handling: Modified the _prepare_input_tensors function to explicitly check for and preserve None inputs, aligning its behavior with _prepare_input_tensors_with_batches which already handles non-tensor inputs.


Changelog
  • flashinfer/autotuner.py
    • Added a check in _prepare_input_tensors to handle None inputs gracefully, preventing crashes.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

The pull request effectively addresses a critical crash by correctly handling None inputs in the _prepare_input_tensors function. The change ensures that optional tensors passed as None are preserved as-is, preventing _create_tensor_like from attempting to access attributes on a None object. This improves the robustness of the autotuner's input preparation process.

Comment thread flashinfer/autotuner.py

            # Some callers pass None for optional tensors (e.g. routing_logits
            # in non-routed MoE). Preserve None as-is.
            tensors.append(None)
        elif any(isinstance(d, DynamicDim) for d in p):
Contributor


Severity: high

The current fix correctly handles None inputs. However, the _create_tensor_like function expects origin_tensor to be a torch.Tensor. If inputs[i] is not None but also not a torch.Tensor (e.g., a Python scalar like an int or float), and p contains DynamicDim, calling _create_tensor_like with a non-tensor object will still lead to a crash (e.g., when trying to access .dtype). To ensure robustness and align with the pattern in _prepare_input_tensors_with_batches that uses isinstance(t, torch.Tensor) checks for non-tensor inputs, the elif condition should explicitly check if inputs[i] is a torch.Tensor before attempting to create a tensor-like object.

Suggested change:

    - elif any(isinstance(d, DynamicDim) for d in p):
    + elif isinstance(inputs[i], torch.Tensor) and any(isinstance(d, DynamicDim) for d in p):
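With the suggestion applied, the branch chain would presumably read along these lines (a fragment sketch reusing the names from the snippet above; the exact _create_tensor_like call is assumed):

    if inputs[i] is None:
        # Preserve optional tensors as-is (the fix in this PR).
        tensors.append(None)
    elif isinstance(inputs[i], torch.Tensor) and any(
        isinstance(d, DynamicDim) for d in p
    ):
        # Only build a tuning tensor for real tensors whose profile
        # has at least one dynamic dimension.
        tensors.append(_create_tensor_like(inputs[i], p))
    else:
        # Static tensors and non-tensor inputs pass through unchanged.
        tensors.append(inputs[i])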

@trevor-m
Contributor

Hi @he-yufeng I tried your PR, but now I get this error:

  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py", line 333, in fused_experts_none_to_flashinfer_trtllm_fp8
    output = trtllm_fp8_block_scale_routed_moe(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 2478, in trtllm_fp8_block_scale_routed_moe
    return get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 1683, in trtllm_fp8_block_scale_moe_op
    _, tactic = tuner.choose_one(
                ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flashinfer/autotuner.py", line 480, in choose_one
    valid_tactics = r.get_valid_tactics(tensors, p)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 1043, in get_valid_tactics
    num_tokens = routing_logits.shape[0]
                 ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'shape'

@he-yufeng he-yufeng requested a review from IwakuraRein as a code owner March 12, 2026 06:34
@he-yufeng
Contributor Author

Thanks for testing @trevor-m! The crash in get_valid_tactics was caused by the same root issue — MoERunner.get_valid_tactics() and MoERunner.forward() both used routing_logits.shape[0] to get num_tokens, which fails when routing_logits is None.

Fixed in 9043934:

  • Changed both methods to use hidden_states.shape[0] (always available)
  • Guarded the shape assertions for topk_ids/expert_weights, which can be empty (0-element) tensors in pre-computed routing mode (see the sketch below)
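A minimal sketch of that pattern (hedged: the real MoERunner methods take more arguments, and the exact assertion bodies are assumed here):

    import torch

    def derive_and_check_num_tokens(hidden_states, routing_logits,
                                    topk_ids, expert_weights):
        # hidden_states is always a real tensor, so it is a safe source
        # for num_tokens; routing_logits is None in the pre-computed
        # routing path (trtllm_fp8_block_scale_routed_moe).
        num_tokens = hidden_states.shape[0]
        if routing_logits is not None:
            assert routing_logits.shape[0] == num_tokens
        # topk_ids / expert_weights can be empty (0-element) tensors
        # depending on the routing mode, so only check non-empty ones.
        if topk_ids is not None and topk_ids.numel() > 0:
            assert topk_ids.shape[0] == num_tokens
        if expert_weights is not None and expert_weights.numel() > 0:
            assert expert_weights.shape[0] == num_tokens
        return num_tokens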

Could you try again with the latest commit?

@trevor-m
Contributor

@he-yufeng Thanks, it's working now

Collaborator

@samuellees samuellees left a comment


Could you please add a smoke test for the fix?

Comment thread flashinfer/fused_moe/core.py (Outdated)

        )
        # topk_ids/expert_weights can be empty(0) when routing_logits is provided,
        # or real tensors when pre-computed routing is used.
        if topk_ids.numel() > 0:
Collaborator


expert_weights is checked like this:

    if expert_weights is not None and expert_weights.numel() > 0:

Could you keep a similar check style for topk_ids, please?

@he-yufeng
Contributor Author

Good catch, updated topk_ids check to match the expert_weights style.
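Presumably that leaves the check reading along these lines (the exact diff isn't shown in this thread):

    if topk_ids is not None and topk_ids.numel() > 0: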

@samuellees
Collaborator

Good catch, updated topk_ids check to match the expert_weights style.

Thanks @he-yufeng ! Could you add a smoke test for your code path, please? I believe the PR will move forward very quickly once the test is ready ^ ^

@he-yufeng
Contributor Author

Added two smoke tests in test_autotuner_core.py — one for _prepare_input_tensors and one for choose_one, both with a None optional tensor. Thanks for the nudge!
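A hedged sketch of what the _prepare_input_tensors smoke test can look like (the actual tests in tests/autotuner/test_autotuner_core.py may wire things up differently; prepare_inputs below is a local stand-in, not the flashinfer API):

    import torch

    def prepare_inputs(inputs):
        # Local stand-in mirroring the fixed None-preserving behavior.
        return [t.clone() if isinstance(t, torch.Tensor) else t
                for t in inputs]

    def test_prepare_inputs_preserves_none():
        inputs = [torch.randn(4, 8), None, torch.randn(4)]
        prepared = prepare_inputs(inputs)
        assert prepared[1] is None
        assert isinstance(prepared[0], torch.Tensor)
        assert isinstance(prepared[2], torch.Tensor)

Per the walkthrough above, the choose_one test similarly passes a None optional input on the no-tuning/inference path and asserts that the provided runner comes back with tactic -1.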

Collaborator

@samuellees samuellees left a comment


LGTM. @he-yufeng Could you please resolve the conflict with main branch? Thanks

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !461 has been created, and the CI pipeline #46957861 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46957861: 1/20 passed

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !461 has been created, and the CI pipeline #47022144 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47022144: 1/20 passed

trtllm_fp8_block_scale_routed_moe passes routing_logits=None for
non-routed calls, but _prepare_input_tensors assumes all inputs are
tensors and crashes with AttributeError: 'NoneType' has no attribute
'dtype' in _create_tensor_like.

Skip None inputs and preserve them as-is, matching the existing
pattern in _prepare_input_tensors_with_batches which already handles
non-tensor inputs gracefully.

Fixes flashinfer-ai#2749

get_valid_tactics() and forward() both accessed routing_logits.shape[0]
to get num_tokens, but routing_logits is None when pre-computed routing
is used (trtllm_fp8_block_scale_routed_moe passes routing_logits=None).

Use hidden_states.shape[0] instead, which is always available.
Also guard the shape assertions for topk_ids/expert_weights that can be
empty(0) tensors depending on the routing mode.

Cover the _prepare_input_tensors and choose_one paths when an optional
tensor (e.g. routing_logits in non-routed MoE) is None, which previously
caused AttributeError on .dtype/.shape.
@he-yufeng he-yufeng force-pushed the fix/autotuner-none-tensor branch from 1f2473b to 260ee5e on March 26, 2026 10:25
@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !461 has been updated with latest changes, and the CI pipeline #47045960 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47045960: 12/20 passed

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !461 has been created, and the CI pipeline #47074819 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47074819: 13/20 passed

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !461 has been created, and the CI pipeline #47092245 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47092245: 11/20 passed

@samuellees
Collaborator

samuellees commented Mar 28, 2026

Hi @he-yufeng , the CI seems to have passed; some errors are unrelated to this PR.
But could you take a look at the pre-commit check failure?
https://github.com/flashinfer-ai/flashinfer/actions/runs/23589405658/job/68979382814?pr=2756

This blocks some other test cases. You can run pre-commit like this:

    pre-commit run -a

Please let me know if you run into any questions~

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !461 has been updated with latest changes, and the CI pipeline #47225408 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47225408: 11/20 passed

@samuellees samuellees enabled auto-merge (squash) March 30, 2026 12:53
@samuellees samuellees merged commit a6796a4 into flashinfer-ai:main Mar 30, 2026
41 of 42 checks passed


Development

Successfully merging this pull request may close these issues.

[Bug] Autotuning fails with trtllm_fp8_block_scale_routed_moe

5 participants