
Prevent MoE autotuner buffer overflow on large token buckets #3025

Merged
nv-yunzheq merged 1 commit into flashinfer-ai:main from leejnau:cutedsl-moe-autotuner-buffer-overflow on Apr 10, 2026

Conversation

@leejnau
Contributor

@leejnau leejnau commented Apr 9, 2026

📌 Description

CuteDslMoEWrapper pre-allocates intermediate buffers sized for
max_num_tokens, but the autotuner can probe buckets larger than that
(e.g. 8192 tokens vs 2048 max). The GEMM kernels then write past
buffer bounds, silently corrupting model weights and eventually
triggering cudaErrorIllegalAddress.

Two fixes:

  • Check num_tokens <= max_num_tokens before reusing pre-allocated
    buffers; fall back to dynamic allocation when exceeded (sketched below).
  • Move tuning_config to instance level so dummy expert IDs span all
    local experts (randint(0, num_experts)) instead of a hardcoded 8,
    which concentrated routing and inflated permutation buffer sizes.
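
A minimal sketch of the first fix, using hypothetical names (wrapper.intermediate_buffer, max_num_tokens); the actual buffer layout and attribute names inside CuteDslMoEWrapper differ:

    import torch

    def get_intermediate_buffer(wrapper, x: torch.Tensor) -> torch.Tensor:
        # Illustrative helper; the real gating lives inside CuteDslMoEWrapper
        # and uses its own buffer and attribute names.
        num_tokens = x.shape[0]
        if num_tokens <= wrapper.max_num_tokens:
            # Within the pre-allocated capacity: reuse the cached buffer so the
            # GEMM kernels stay inside the memory they were sized for.
            return wrapper.intermediate_buffer[:num_tokens]
        # The autotuner probed a bucket larger than the preallocation (e.g. 8192
        # tokens vs a 2048-token capacity): allocate a fresh buffer for this call
        # instead of writing past the cached buffer's bounds.
        return torch.empty(
            (num_tokens, wrapper.intermediate_buffer.shape[1]),
            dtype=wrapper.intermediate_buffer.dtype,
            device=x.device,
        )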

🔍 Related Issues

feat: cuteDSL fp4 moe for better DSR1 performance.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Performance Improvements

    • CUDA-graph preallocation now respects input-size limits to avoid over-allocation and improve memory behavior for larger token batches.
    • Memory gating refined to prevent oversized buffer reuse, improving inference stability under varying workloads.
  • Stability & Tuning

    • Auto-tuner configuration is now per-run, improving tuning accuracy and consistency.
    • Runtime tuning initialization adapts to the configured expert count for more representative profiling.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 9, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 42e5321a-152f-4e1d-869b-6017b2608a6d

📥 Commits

Reviewing files that changed from the base of the PR and between 4bb63033fdcf49e5c2837ee737cdf046ee26d34f and 127868d.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • flashinfer/fused_moe/cute_dsl/tuner.py

📝 Walkthrough

Moved MoE autotuner config to per-run instances, tightened CUDA-graph preallocation gating to respect incoming token count and tile size, and made dummy expert-ID initialization dynamic to cover the configured number of experts at runtime.

Changes

  • CUDA-graph preallocation & autotuner usage (flashinfer/fused_moe/cute_dsl/fused_moe.py): Added a token-count check (x.shape[0] <= self.max_num_tokens) to the CUDA-graph preallocation gating; updated tuner.choose_one calls to pass the runner instance's tuning_config (self._runner.tuning_config / runner.tuning_config) instead of the class-level config.
  • Tuning configuration refactor & tensor initializers (flashinfer/fused_moe/cute_dsl/tuner.py): Moved tuning_config from module/class scope into CuteDslFusedMoENvfp4Runner.__init__ as self.tuning_config; dynamic tensor initializers are preserved, but token_selected_experts now samples expert IDs with randint(0, max(num_experts, 1)) to span the configured expert count instead of a fixed 0..7 range.
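
A simplified, hypothetical sketch of the tuner-side change; the Runner class, _build_tuning_config, and the way shapes maps to dimensions are illustrative stand-ins rather than the actual flashinfer API (only the (shapes, dtype, device) initializer signature and the randint range come from the change itself):

    import torch

    def make_dummy_expert_ids(num_tokens: int, top_k: int, num_experts: int,
                              device: str = "cuda") -> torch.Tensor:
        # Dummy routing tensor used only during profiling: sample expert IDs
        # across all local experts (0 .. num_experts - 1) instead of a hardcoded
        # 0..7 range, so permutation buffers reflect spread-out routing.
        return torch.randint(0, max(num_experts, 1), (num_tokens, top_k),
                             dtype=torch.int32, device=device)

    class Runner:  # stand-in for CuteDslFusedMoENvfp4Runner
        def __init__(self, num_experts: int, top_k: int):
            self.num_experts = num_experts
            self.top_k = top_k
            # Per-instance tuning config (previously class-level), so the dummy
            # tensor initializers can close over this runner's expert count.
            self.tuning_config = self._build_tuning_config()

        def _build_tuning_config(self) -> dict:
            return {
                "tensor_initializers": [
                    # token_selected_experts initializer; shapes[0] is assumed
                    # to carry the token count of the profiling bucket.
                    lambda shapes, dtype, device: make_dummy_expert_ids(
                        shapes[0], self.top_k, self.num_experts, device),
                ],
            }

With 32 local experts, for example, make_dummy_expert_ids(2048, 8, 32) spreads the 2048 dummy tokens across all 32 experts instead of piling them onto experts 0..7, which is what produced the inflated permutation buffer sizes.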

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • aleozlx
  • samuellees
  • IwakuraRein
  • jiahanc

Poem

🐇 I hop through configs, small and neat,
Per-run tuning now takes its seat.
Prealloc waits for tokens to play,
Experts sampled fresh each day.
A tiny leap — the model's fleet.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 60.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title accurately and specifically describes the primary fix: preventing buffer overflow in the MoE autotuner when processing large token buckets.
  • Description check (✅ Passed): The description covers the main sections of the template: detailed problem description, two concrete fixes, related issue, and completed pre-commit/test checklists.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request enhances the MoE implementation by adding a check to ensure that pre-allocated buffers are only used when the current batch size is within the pre-allocated capacity, preventing potential buffer overflows. Additionally, it refactors the tuning configuration in CuteDslFusedMoENvfp4Runner from a class-level to an instance-level attribute, which allows for more accurate profiling by using the actual number of experts instead of a hardcoded value. I have no feedback to provide as no review comments were submitted.

@nv-yunzheq
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !530 has been created, and the CI pipeline #48146400 is currently running. I'll report back once the pipeline job completes.

@nv-yunzheq nv-yunzheq force-pushed the cutedsl-moe-autotuner-buffer-overflow branch from 97bbb5e to 4bb6303 on April 9, 2026 19:58
@nv-yunzheq
Collaborator

/bot cancel

@flashinfer-bot
Collaborator

Unknown Command

Command /bot cancel is not recognized.

Use /bot help for available commands.

@nv-yunzheq
Collaborator

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48146400 has been cancelled.

@nv-yunzheq
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !530 has been updated with latest changes, and the CI pipeline #48147203 is currently running. I'll report back once the pipeline job completes.

),
tensor_initializers=[
# 0: x — FP4 quantized input (uint8 packed)
lambda shapes, dtype, device: torch.randint(
Collaborator


Do we need to fix the random seed to guarantee consistent results here?

Contributor Author


Do we need to fix the random seed to guarantee consistent results here?

No, since these are throwaway dummy tensors for timing kernel execution during profiling. The values don't affect tactic selection.
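
For context, a rough sketch of how such initializers are typically consumed during profiling; this is illustrative only, not the actual flashinfer autotuner loop, and runner.run with a tactic argument is hypothetical. The outputs are discarded and only latency is compared, so the dummy values need realistic shapes and routing spread, not a fixed seed:

    import torch

    def profile_tactics(runner, tactics, tensor_initializers, shapes, dtype, device="cuda"):
        timings = {}
        for tactic in tactics:
            # Fresh dummy inputs per tactic; their numeric values are never checked.
            inputs = [init(shapes, dtype, device) for init in tensor_initializers]
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            runner.run(*inputs, tactic=tactic)  # output is thrown away
            end.record()
            torch.cuda.synchronize()
            timings[tactic] = start.elapsed_time(end)  # milliseconds
        # Only relative latency matters for tactic selection.
        return min(timings, key=timings.get)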

@nv-yunzheq nv-yunzheq force-pushed the cutedsl-moe-autotuner-buffer-overflow branch from 4bb6303 to 127868d on April 10, 2026 17:27
@nv-yunzheq nv-yunzheq merged commit a1166dc into flashinfer-ai:main Apr 10, 2026
31 of 35 checks passed
@nvpohanh
Contributor

@leejnau Thanks for fixing this!
