fix: use current CUDA device instead of tp_rank for SymmDeviceMemory allocation#2662

Merged
aleozlx merged 5 commits into flashinfer-ai:main from fzyzcjy:fix/allreduce-device-idx
Mar 24, 2026

Conversation

@fzyzcjy
Collaborator

@fzyzcjy fzyzcjy commented Mar 1, 2026

Summary

  • trtllm_create_ipc_workspace_for_all_reduce_fusion passes tp_rank as the device_idx to SymmDeviceMemory, which internally calls cudaSetDevice(device_idx). This assumes tp_rank == gpu_device_id.
  • This breaks when the caller uses a non-zero base GPU offset (e.g. SGLang's --base-gpu-id 4 maps TP ranks 0–3 to cuda:4–7).
  • The mismatch causes SymmDeviceMemory to allocate buffers on the wrong GPU and clobber the active CUDA device context, leading to CUBLAS_STATUS_EXECUTION_FAILED on subsequent operations.
  • Same issue in the MNNVL path (MNNVLAllReduceFusionWorkspace) where mapping.local_rank is used as device index.

Fix

Use torch.cuda.current_device() — which reflects the actual GPU the caller has selected — instead of deriving the device index from the TP rank.
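
A minimal sketch of the change (the wrapper function name is hypothetical; the actual edit replaces torch.device("cuda", tp_rank).index with torch.cuda.current_device() inside trtllm_create_ipc_workspace_for_all_reduce_fusion):

import torch

def _workspace_device_index(tp_rank: int) -> int:
    # Old (simplified): derive the device index from the TP rank, which is
    # only correct when tp_rank == gpu_device_id.
    #     return torch.device("cuda", tp_rank).index
    # New: use whatever device the caller has already bound via
    # torch.cuda.set_device(), so a non-zero base GPU offset keeps working.
    return torch.cuda.current_device()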

Repro

# Crashes with CUBLAS_STATUS_EXECUTION_FAILED
python -m sglang.launch_server --model-path Qwen/Qwen3-30B-A3B --tp 4 --base-gpu-id 4

# Works (because GPUs appear as 0-3 internally)
CUDA_VISIBLE_DEVICES=4,5,6,7 python -m sglang.launch_server --model-path Qwen/Qwen3-30B-A3B --tp 4

Test plan

  • Verified on 8×H200 node: --base-gpu-id 4 --tp 4 crashes without fix, passes with fix
  • Existing CI tests should still pass (no behavior change when tp_rank == device_id, which is the common case)

Summary by CodeRabbit

Release Notes

  • Bug Fixes

    • Fixed device assignment in all-reduce fusion operations to correctly handle cases where CUDA device indices don't align with tensor parallelism ranks.
    • Improved device binding for multicast-related buffer allocations to use the current CUDA device accurately.
  • Tests

    • Added test coverage for all-reduce fusion with GPU offset scenarios.

…allocation

`trtllm_create_ipc_workspace_for_all_reduce_fusion` passes `tp_rank` as
the `device_idx` to `SymmDeviceMemory`, which internally calls
`cudaSetDevice(device_idx)`. This assumes `tp_rank == gpu_device_id`,
which breaks when the caller uses a non-zero base GPU offset (e.g.
SGLang's `--base-gpu-id 4` maps TP ranks 0-3 to cuda:4-7).

The mismatch causes SymmDeviceMemory to allocate buffers on the wrong
GPU and clobber the active CUDA device context, leading to
`CUBLAS_STATUS_EXECUTION_FAILED` on subsequent operations.

Fix both the TRTLLM and MNNVL workspace creation paths to use
`torch.cuda.current_device()` — which reflects the actual GPU the
caller has selected — instead of deriving the device index from the
TP rank.
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in distributed GPU environments where the system's base-gpu-id is not zero. It ensures that memory allocations for all-reduce fusion workspaces correctly identify and utilize the currently active CUDA device, preventing crashes and ensuring proper operation when GPU device IDs do not directly correspond to tensor parallelism ranks.

Highlights

  • Corrected CUDA Device Selection: Modified trtllm_create_ipc_workspace_for_all_reduce_fusion and MNNVLAllReduceFusionWorkspace to use torch.cuda.current_device() instead of tp_rank or local_rank for SymmDeviceMemory allocation.
  • Resolved base-gpu-id Incompatibility: Fixed an issue where non-zero base-gpu-id configurations (e.g., in SGLang) caused CUBLAS_STATUS_EXECUTION_FAILED due to incorrect GPU device indexing and memory allocation.


Changelog
  • flashinfer/comm/trtllm_ar.py
    • Replaced torch.device("cuda", tp_rank).index with torch.cuda.current_device() when creating the IPC workspace.
  • flashinfer/comm/trtllm_mnnvl_ar.py
    • Changed torch.device("cuda", mapping.local_rank) to torch.device("cuda", torch.cuda.current_device()) for McastGPUBuffer allocation.
    • Added a comment explaining the change to support base_gpu_id != 0 scenarios.
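
A sketch of the MNNVL device selection change listed above (the McastGPUBuffer constructor arguments are elided because the exact signature isn't shown here; the point is only which device object is passed):

import torch

# Old (simplified): device derived from the mapping's local rank.
#     device = torch.device("cuda", mapping.local_rank)
# New: bind the buffer to the caller's current CUDA device so allocation
# stays correct when base_gpu_id != 0.
device = torch.device("cuda", torch.cuda.current_device())
# mcast_buffer = McastGPUBuffer(..., device)  # constructor arguments elided
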
Activity
  • The author verified the fix on an 8xH200 node, confirming that the --base-gpu-id 4 --tp 4 configuration, which previously crashed, now passes with the applied changes.
  • Existing CI tests are expected to pass without behavior changes for common tp_rank == device_id cases.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 1, 2026

📝 Walkthrough

The PR fixes device identifier handling in all-reduce and multicast operations by using torch.cuda.current_device() instead of rank-based device derivation, and extends test coverage to validate GPU offset scenarios where CUDA device indices differ from tensor parallelism ranks.

Changes

  • Device Binding Fixes (flashinfer/comm/trtllm_ar.py, flashinfer/comm/trtllm_mnnvl_ar.py): Changed the device identifier from rank-based derivation to torch.cuda.current_device() for SymmDeviceMemory and McastGPUBuffer allocation when a base_gpu_id offset is present.
  • Test Coverage Enhancement (tests/comm/test_trtllm_allreduce_fusion.py): Added a gpu_offset parameter to the worker and multi_process_parallel to support testing AllReduce fusion with non-aligned device indices; introduced the new test function test_trtllm_allreduce_fusion_gpu_offset.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • #2118 — Modifies device/buffer binding in trtllm_mnnvl_ar.py for multicast GPU buffer construction.
  • #1991 — Adjusts CUDA device binding in trtllm_create_ipc_workspace_for_all_reduce_fusion within trtllm_ar.py.
  • #2239 — Touches SymmDeviceMemory and McastGPUBuffer usage patterns refined by these device-binding fixes.

Suggested labels

op: comm

Suggested reviewers

  • yzh119
  • nvmbreughe
  • wenscarl
  • bkryu


🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check: ✅ Passed. The title clearly and concisely describes the main fix (using the current CUDA device instead of tp_rank for SymmDeviceMemory allocation) and directly addresses the core problem in the changeset.
  • Description check: ✅ Passed. The description is comprehensive and well-structured, explaining the problem, root cause, fix, reproducer, and test plan. It exceeds the basic template requirements with clear technical context.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request addresses a bug where the tensor parallelism rank (tp_rank) or local rank was incorrectly assumed to be the CUDA device index. This assumption fails in scenarios with a non-zero base GPU offset, leading to memory allocation on the wrong device and subsequent CUDA errors. The fix correctly replaces the usage of tp_rank and mapping.local_rank with torch.cuda.current_device() to ensure that memory is allocated on the actual active CUDA device for the process. The changes in flashinfer/comm/trtllm_ar.py and flashinfer/comm/trtllm_mnnvl_ar.py are accurate and effectively resolve the described issue. I've added one suggestion to improve code clarity by adding a comment for consistency.

Comment thread flashinfer/comm/trtllm_ar.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 135-143: The device used to create the workspace is inconsistent:
McastGPUBuffer is constructed with torch.device("cuda",
torch.cuda.current_device()) but self.buffer_flags is allocated using
mapping.local_rank, which can point to a different CUDA device when base_gpu_id
!= 0; update the allocation of self.buffer_flags to use the same CUDA device as
McastGPUBuffer (use torch.cuda.current_device() or the previously built
torch.device) instead of mapping.local_rank so both workspace objects reference
the same GPU (look for McastGPUBuffer construction and the code that allocates
self.buffer_flags).
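
A minimal sketch of the consistency fix this comment asks for (the tensor shape and dtype below are placeholders, not the real buffer_flags layout):

import torch

# Use one device object, taken from the caller's current CUDA device, for
# both workspace allocations instead of deriving buffer_flags' device from
# mapping.local_rank.
device = torch.device("cuda", torch.cuda.current_device())
# mcast_buffer = McastGPUBuffer(..., device)  # constructor arguments elided
buffer_flags = torch.zeros(4, dtype=torch.int32, device=device)  # placeholder shape/dtype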

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f521fe1 and 1ab6c9f.

📒 Files selected for processing (2)
  • flashinfer/comm/trtllm_ar.py
  • flashinfer/comm/trtllm_mnnvl_ar.py

Comment thread flashinfer/comm/trtllm_mnnvl_ar.py
- Add explanatory comment for torch.cuda.current_device() usage in
  trtllm_ar.py (gemini review)
- Fix buffer_flags allocation to also use current_device() instead of
  mapping.local_rank (coderabbit review)
fzyzcjy added a commit to radixark/miles that referenced this pull request Mar 1, 2026
…educe bug

FlashInfer allreduce fusion crashes with base_gpu_id != 0 (colocate
mode). Enable DP attention so the auto-enable logic is skipped.
See flashinfer-ai/flashinfer#2662
fzyzcjy added 3 commits March 1, 2026 17:36
Test validates that allreduce fusion works correctly when the CUDA
device index differs from the TP rank, simulating sglang colocate mode
where inference engines run on non-zero base GPUs (e.g. GPUs 4-7 with
TP ranks 0-3).

Tests both legacy and unified APIs with CUDA graph capture.
Instead of a separate test file, extend the existing worker and
multi_process_parallel with a gpu_offset parameter (default 0, no
behavior change for existing tests).

New test_trtllm_allreduce_fusion_gpu_offset reuses the same worker
with gpu_offset = available_gpus - world_size, validating that
allreduce fusion works when CUDA device index != TP rank.
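
A condensed sketch of how the offset test binds devices (the helper name is hypothetical; the worker body and the actual fusion calls are elided):

import torch

def _select_worker_device(rank: int, gpu_offset: int = 0) -> int:
    # With gpu_offset = torch.cuda.device_count() - world_size, TP ranks map
    # onto the last world_size GPUs (e.g. ranks 0-3 -> cuda:4-7 on an 8-GPU
    # node), so the CUDA device index no longer equals the TP rank.
    device_idx = rank + gpu_offset
    torch.cuda.set_device(device_idx)
    # The workspace creation under test must then pick up this device via
    # torch.cuda.current_device() rather than assuming device_idx == rank.
    return device_idx
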
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/comm/test_allreduce_fusion_gpu_offset.py (1)

255-262: Use iterable unpacking for proc_args (RUF005).

Ruff flagged tuple concatenation here; unpacking is cleaner and avoids the lint warning.

♻️ Suggested refactor
-        proc_args: tuple[Any, ...] = (
-            world_size,
-            i,
-            gpu_offset,
-            dtype,
-            hidden_dim,
-            distributed_init_port,
-        ) + target_args
+        proc_args: tuple[Any, ...] = (
+            world_size,
+            i,
+            gpu_offset,
+            dtype,
+            hidden_dim,
+            distributed_init_port,
+            *target_args,
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_allreduce_fusion_gpu_offset.py` around lines 255 - 262,
Replace the tuple concatenation that builds proc_args with iterable unpacking:
instead of joining two tuples, use the starred-unpacking form to include
target_args directly (update the proc_args assignment where proc_args and
target_args are used in tests/comm/test_allreduce_fusion_gpu_offset.py). Ensure
proc_args = (world_size, i, gpu_offset, dtype, hidden_dim,
distributed_init_port, *target_args) so Ruff RUF005 is satisfied and behavior is
unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/comm/test_allreduce_fusion_gpu_offset.py`:
- Around line 1-314: The test file was reformatted by ruff-format; run the
project's formatter (pre-commit / ruff-format) on this file and commit the
changes so CI passes. Specifically, run the formatter (or `pre-commit run
--all-files` / `ruff format tests/comm/test_allreduce_fusion_gpu_offset.py`),
review the changes affecting functions like _worker_with_gpu_offset,
_run_multi_process, and test_allreduce_fusion_gpu_offset, stage and commit the
reformatted file, and push the commit.
- Around line 51-83: Initialize ipc_handles and workspace to None before the
setup try-block and guard all teardown/cleanup code to only act if those
variables are not None (references: ipc_handles, workspace,
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion,
comm.create_allreduce_fusion_workspace). Also protect the unconditional teardown
barrier/torch.distributed calls by checking torch.distributed.is_available() and
torch.distributed.is_initialized() (and that world_size>1) or wrap the barrier
in a try/except to avoid hangs when another rank has failed; apply the same
guards to the other similar block around the code referenced for lines 220-228.
- Around line 272-293: Add an architecture-based skip at the start of
test_allreduce_fusion_gpu_offset to avoid RuntimeError when calling allreduce on
unsupported GPUs: import and call the flashinfer.utils helper (e.g.,
is_sm80_supported) and call pytest.skip(...) if it returns False. Place this
check before any CUDA seeds or device-count logic so the test is skipped early;
reference the test function name (test_allreduce_fusion_gpu_offset) and the
backend-protected function (allreduce) to clarify why the helper is required.
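
A minimal sketch of the setup/teardown guards requested in the inline comments above (function and variable names are illustrative, not the test's actual structure):

import torch.distributed as dist

def _guarded_worker(world_size: int) -> None:
    ipc_handles = None
    workspace = None
    try:
        pass  # create the workspace / IPC handles here
    finally:
        # Only clean up what was actually created, and only hit the barrier
        # when the process group is usable, so a failed rank cannot hang the
        # surviving ranks during teardown.
        if workspace is not None:
            pass  # free the workspace
        if ipc_handles is not None:
            pass  # destroy the IPC handles
        if dist.is_available() and dist.is_initialized() and world_size > 1:
            try:
                dist.barrier()
            except Exception:
                pass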

---

Nitpick comments:
In `@tests/comm/test_allreduce_fusion_gpu_offset.py`:
- Around line 255-262: Replace the tuple concatenation that builds proc_args
with iterable unpacking: instead of joining two tuples, use the
starred-unpacking form to include target_args directly (update the proc_args
assignment where proc_args and target_args are used in
tests/comm/test_allreduce_fusion_gpu_offset.py). Ensure proc_args = (world_size,
i, gpu_offset, dtype, hidden_dim, distributed_init_port, *target_args) so Ruff
RUF005 is satisfied and behavior is unchanged.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 61e6516 and 3c628a5.

📒 Files selected for processing (1)
  • tests/comm/test_allreduce_fusion_gpu_offset.py

Comment thread tests/comm/test_allreduce_fusion_gpu_offset.py Outdated
Comment thread tests/comm/test_allreduce_fusion_gpu_offset.py Outdated
Comment thread tests/comm/test_allreduce_fusion_gpu_offset.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/comm/test_trtllm_allreduce_fusion.py (1)

437-447: Consider iterable unpacking for cleaner tuple construction.

The static analyzer flags the tuple concatenation. Iterable unpacking improves readability and avoids nested parentheses.

♻️ Suggested refactor
         proc_args = (
-            (
-                world_size,
-                i,
-                dtype,
-                hidden_dim,
-                distributed_init_port,
-            )
-            + target_args
-            + (gpu_offset,)
+            world_size,
+            i,
+            dtype,
+            hidden_dim,
+            distributed_init_port,
+            *target_args,
+            gpu_offset,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/comm/test_trtllm_allreduce_fusion.py` around lines 437 - 447, Replace
the tuple concatenation used to build proc_args with iterable unpacking to
improve readability: instead of joining tuples with +, construct proc_args by
listing the fixed elements (world_size, i, dtype, hidden_dim,
distributed_init_port), then unpack target_args with *target_args and append
gpu_offset as the final element; update the code that assigns proc_args
accordingly (look for the proc_args variable near target_args and gpu_offset).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/comm/test_trtllm_allreduce_fusion.py`:
- Around line 437-447: Replace the tuple concatenation used to build proc_args
with iterable unpacking to improve readability: instead of joining tuples with
+, construct proc_args by listing the fixed elements (world_size, i, dtype,
hidden_dim, distributed_init_port), then unpack target_args with *target_args
and append gpu_offset as the final element; update the code that assigns
proc_args accordingly (look for the proc_args variable near target_args and
gpu_offset).

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3c628a5 and 0c80eba.

📒 Files selected for processing (1)
  • tests/comm/test_trtllm_allreduce_fusion.py

@fzyzcjy fzyzcjy added the run-ci label Mar 1, 2026
@aleozlx aleozlx self-assigned this Mar 2, 2026
@aleozlx
Collaborator

aleozlx commented Mar 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !364 has been created, and the CI pipeline #45169844 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #45169844: canceled

@aleozlx
Collaborator

aleozlx commented Mar 23, 2026

/bot run

@aleozlx aleozlx added run-ci and removed run-ci labels Mar 23, 2026
@flashinfer-bot
Collaborator

GitLab MR !364 has been created, and the CI pipeline #46827216 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx enabled auto-merge (squash) March 24, 2026 02:18
@flashinfer-bot
Collaborator

[FAILED] Pipeline #46827216: 13/20 passed

@aleozlx aleozlx merged commit b893192 into flashinfer-ai:main Mar 24, 2026
88 of 125 checks passed
@aleozlx aleozlx linked an issue Apr 1, 2026 that may be closed by this pull request

Development

Successfully merging this pull request may close these issues.

[BUG] AllReduceFusion doesn't support multiple distributed groups

3 participants