fix: TRT-LLM MHA CUDA illegal address with EAGLE v2 + DP attention#21649

Merged
Kangyan-Zhou merged 4 commits into sgl-project:main from Kangyan-Zhou:fix-trtllm-mha-dp-batch-size
Apr 5, 2026

Conversation

@Kangyan-Zhou (Collaborator)

Summary

  • Fix CUDA_ERROR_ILLEGAL_ADDRESS in TRT-LLM FMHA kernel during EAGLE v2 speculative decoding with DP attention and multimodal inputs
  • Store batch_size in TRTLLMMHAMetadata at init time and use it in forward_extend, instead of forward_batch.batch_size, which may be inflated by DP padding (see the sketch below)
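
A minimal sketch of this pattern (the class and method names follow the PR description; the backend class name and the surrounding arguments are illustrative, not the actual diff):

```python
from dataclasses import dataclass


@dataclass
class TRTLLMMHAMetadata:
    batch_size: int = 0  # new field: the real, pre-padding batch size
    # page_table, cache_seqlens, cu_seqlens_q/k elided for brevity


class TRTLLMMHABackend:  # illustrative name for the trtllm_mha backend class
    def init_forward_metadata(self, forward_batch):
        metadata = TRTLLMMHAMetadata()
        # Runs before prepare_mlp_sync_batch inflates forward_batch.batch_size,
        # so this records the real per-rank value.
        metadata.batch_size = forward_batch.batch_size
        self.forward_metadata = metadata

    def forward_extend(self, q, k, v, layer, forward_batch):
        # Use the stored pre-padding value, not forward_batch.batch_size,
        # which DP padding may have inflated by this point.
        batch_size = self.forward_metadata.batch_size
        return batch_size  # kernel call elided
```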

Root Cause

When DP attention is enabled with EAGLE v2, prepare_mlp_sync_batch (forward_batch_info.py:891) inflates forward_batch.batch_size to match the max across DP ranks for MLP synchronization. For example, DP0 has 10 requests but batch_size gets inflated to 12 to match DP1/2/3.

However, init_forward_metadata has already computed metadata tensors (page_table, cache_seqlens, cu_seqlens_q/k) for the original batch_size of 10. The TRT-LLM FMHA kernel in forward_extend (line 874) was passing the inflated forward_batch.batch_size=12 while the metadata tensors only had 10 entries. The kernel iterates over 12 requests, reads indices 10 and 11 past the tensor boundaries, and hits unmapped GPU memory.
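
The mismatch can be illustrated with a toy sketch (plain PyTorch on CPU, not sglang code; on the GPU the out-of-bounds read is not caught and surfaces as an illegal address):

```python
import torch

real_bs = 10    # requests actually on this DP rank
padded_bs = 12  # forward_batch.batch_size after prepare_mlp_sync_batch

# Metadata tensors were sized for the real batch, before padding happened.
cache_seqlens = torch.randint(1, 128, (real_bs,), dtype=torch.int32)

# The kernel's per-request loop is driven by the inflated count.
for i in range(padded_bs):
    if i >= cache_seqlens.shape[0]:
        print(f"request {i}: past the tensor boundary "
              "(an unmapped read on the GPU)")
        continue
    _ = cache_seqlens[i]
```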

The TMA descriptors in fmhaKernels.cuh are configured with CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, so out-of-bounds access causes a hard CUDA_ERROR_ILLEGAL_ADDRESS rather than being clamped.

Other attention backends (FlashInfer native, Triton) are immune because they derive batch_size from metadata tensor shapes rather than using forward_batch.batch_size as an explicit parameter.

Reproduction

  • Model: Qwen3.5-397B-A17B-FP8 on 4x B200 NVL4 (GB300)
  • Config: --tp 4 --dp-size=4 --enable-dp-attention --speculative-algorithm=EAGLE --speculative-num-steps=3 --speculative-eagle-topk=1 --speculative-num-draft-tokens=4 --attention-backend=trtllm_mha --enable-multimodal --mamba-scheduler-strategy=extra_buffer --page-size=64
  • Trigger: MMMU-Pro VLM eval with 512 concurrent requests via NeMo Skills
  • Crash: CUDA_ERROR_ILLEGAL_ADDRESS at fmhaKernels.cuh:304 in _draft_extend_for_decode -> forward_extend, consistently at ~388-405/500 questions
Stack trace (with CUDA_LAUNCH_BLOCKING=1)
CUDA Error: CUDA_ERROR_ILLEGAL_ADDRESS /workspace/include/flashinfer/trtllm/fmha/fmhaKernels.cuh 304

eagle_worker_v2.py:736  forward_batch_generation
eagle_worker_v2.py:591  _draft_extend_for_decode
model_runner.py:2561    forward_extend
qwen3_5_mtp.py:146      MTP draft model forward
qwen3_5.py:977           decoder layer
qwen3_5.py:833           self_attention
qwen3_5.py:814           gate = torch.sigmoid(gate)

torch.AcceleratorError: CUDA error: an illegal memory access was encountered

Verification

After fix: 1730/1730 MMMU-Pro questions completed without crash (accuracy 78.55%), on the exact config that previously crashed at ~388-405.

Server log from crash reproduction attached

Test plan

  • Reproduced crash on standalone pod with CUDA_LAUNCH_BLOCKING=1
  • Applied fix, re-ran 500/500 MMMU-Pro eval — no crash
  • Re-ran full 1730/1730 MMMU-Pro eval — no crash, accuracy 78.55%
  • CI test for EAGLE v2 + DP attention + trtllm_mha (requires SM100 hardware)

🤖 Generated with Claude Code

@github-actions Bot added the blackwell (SM100/SM120) label on Mar 30, 2026

@gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces a batch_size field to the TRTLLMMHAMetadata class and ensures it is populated during metadata initialization, to prevent issues with batch sizes inflated by DP padding. The review points out that this initialization is missing in the init_forward_metadata_capture_cuda_graph function, which could still lead to illegal address errors when running with CUDA graphs.

# From init_forward_metadata (the non-CUDA-graph path):
metadata = TRTLLMMHAMetadata()
seqlens_in_batch = forward_batch.seq_lens
batch_size = forward_batch.batch_size
metadata.batch_size = batch_size  # records the real batch size before DP padding
Severity: high

This correctly stores the batch size for the non-CUDA graph path. However, the same logic appears to be missing for the CUDA graph path in init_forward_metadata_capture_cuda_graph.

In that function, metadata.batch_size is not set, so it will default to 0. When forward_extend is called in a CUDA graph context with DP attention, this will likely cause the same CUDA_ERROR_ILLEGAL_ADDRESS this PR aims to fix.

To ensure the fix is complete, please initialize metadata.batch_size in init_forward_metadata_capture_cuda_graph as well. For example, you could add metadata.batch_size = bs at the beginning of the function.
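
A sketch of the suggested change (the real init_forward_metadata_capture_cuda_graph takes more parameters; the signature here is abbreviated for illustration):

```python
def init_forward_metadata_capture_cuda_graph(self, bs, *args, **kwargs):
    metadata = TRTLLMMHAMetadata()  # as defined in this PR
    # Mirror the non-graph path: record the capture batch size so the field
    # does not stay at its default of 0 when forward_extend later runs
    # inside a captured CUDA graph.
    metadata.batch_size = bs
    # ... remaining metadata setup elided ...
```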

…ddress

When DP attention is enabled with EAGLE v2 speculative decoding,
`prepare_mlp_sync_batch` inflates `forward_batch.batch_size` to match
the max across DP ranks for MLP synchronization. However,
`init_forward_metadata` has already computed metadata tensors
(page_table, cache_seqlens, cu_seqlens_q/k) for the original,
smaller batch_size.

The TRT-LLM FMHA kernel in `forward_extend` was using the inflated
`forward_batch.batch_size`, causing it to read past the metadata tensor
boundaries. This triggers `CUDA_ERROR_ILLEGAL_ADDRESS` in
`fmhaKernels.cuh` when the kernel accesses invalid page table entries
via TMA descriptors (configured with OOB_FILL_NONE).

The fix stores `batch_size` in `TRTLLMMHAMetadata` at init time and
uses `self.forward_metadata.batch_size` in the kernel call, which
is the correct pre-padding value.

Reproduction:
- Model: Qwen3.5-397B-A17B-FP8 on 4x B200 NVL4 (GB300)
- Config: --tp 4 --dp-size=4 --enable-dp-attention
          --speculative-algorithm=EAGLE --attention-backend=trtllm_mha
          --enable-multimodal
- Trigger: 500+ concurrent multimodal (MMMU-Pro) requests
- Crash: CUDA_ERROR_ILLEGAL_ADDRESS at fmhaKernels.cuh:304
         in _draft_extend_for_decode -> forward_extend

Verified: 1730/1730 MMMU-Pro questions completed without crash
after fix (previously crashed at ~388-405).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Kangyan-Zhou force-pushed the fix-trtllm-mha-dp-batch-size branch from 67e8502 to e26b7cd on March 30, 2026 03:21
@Fridge003 (Collaborator)

/rerun-stage stage-c-test-4-gpu-b200

@github-actions Bot (Contributor)

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@Kangyan-Zhou marked this pull request as ready for review on March 30, 2026 04:34
@Qiaolin-Yu (Collaborator) left a comment


My intuition is that when padding is introduced by DP attention, some information in the forward batch becomes inconsistent with the metadata. But in this case, shouldn’t the information from the forward batch be the correct one, since it reflects the padded state? Why does this code choose to use the metadata when the two are inconsistent?

The previous commit stored batch_size in TRTLLMMHAMetadata and used it
in forward_extend, but only set it in init_forward_metadata (non-CUDA-graph
path). init_forward_metadata_capture_cuda_graph left it at the default 0,
causing CUDA_ERROR_INVALID_VALUE during EAGLE v1 draft extend graph capture
with trtllm_mha backend.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Kangyan-Zhou (Collaborator, Author)

> My intuition is that when padding is introduced by DP attention, some information in the forward batch becomes inconsistent with the metadata. But in this case, shouldn’t the information from the forward batch be the correct one, since it reflects the padded state? Why does this code choose to use the metadata when the two are inconsistent?

IIUC this is exactly the issue that causes the IMA: the batch size increases, and that inflated count is what the trtllm_mha kernel uses to access the data.

The CI failure https://github.com/sgl-project/sglang/actions/runs/23727613783/job/69114314668 is caused by a behavior difference between spec v1 and v2: spec v2 does not capture draft CUDA graph state for the trtllm_mha backend, while the test uses v1. Let me re-trigger the test to see whether it works now.

@Kangyan-Zhou (Collaborator, Author)

/rerun-stage stage-c-test-4-gpu-b200

@github-actions Bot (Contributor)

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@Qiaolin-Yu (Collaborator)

> IIUC this is exactly the issue that causes the IMA: the batch size increases, and that inflated count is what the trtllm_mha kernel uses to access the data.
>
> The CI failure https://github.com/sgl-project/sglang/actions/runs/23727613783/job/69114314668 is caused by a behavior difference between spec v1 and v2: spec v2 does not capture draft CUDA graph state for the trtllm_mha backend, while the test uses v1. Let me re-trigger the test to see whether it works now.

What I mean is that there may indeed be an inconsistency here: some metadata has been padded while some has not. However, I think the correct fix would be to pad the non-padded attributes for DP attention, rather than using the non-padded version.

@ispobock (Collaborator)

ispobock commented Apr 2, 2026

init_forward_metadata runs before prepare_mlp_sync_batch inflates batch_size. The padding is only needed for the MLP collective communication (all-gather/reduce-scatter), not for attention: each DP rank computes attention independently on its own real requests. So the metadata tensors (page_table, cu_seqlens) should always reflect the real batch size, not the padded one.

Other attention backends derive the batch size from metadata tensor shapes, while the trtllm_mha backend reads forward_batch.batch_size directly in forward_extend, which by that point has already been inflated.

So I think we can use cu_seqlens_q.shape[0] - 1 to get the batch size directly in forward_extend (see the sketch below). That stays consistent with what the kernel actually has.
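
A sketch of this suggestion inside forward_extend (tensor and attribute names as used earlier in this thread; a fragment, not the actual diff):

```python
# cu_seqlens_q holds one cumulative offset per request plus a leading zero,
# so its length encodes the real, pre-padding batch size.
metadata = self.forward_metadata
batch_size = metadata.cu_seqlens_q.shape[0] - 1
```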

Use cu_seqlens_q.shape[0] - 1 to get the real batch size in
forward_extend, consistent with how other attention backends work.
This removes the need for a separate batch_size field on
TRTLLMMHAMetadata.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Kangyan-Zhou force-pushed the fix-trtllm-mha-dp-batch-size branch from 5324457 to 89fdfee on April 3, 2026 19:05
@Kangyan-Zhou (Collaborator, Author)

/rerun-stage stage-c-test-4-gpu-b200

@github-actions Bot (Contributor)

github-actions Bot commented Apr 3, 2026

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@Kangyan-Zhou (Collaborator, Author)

/tag-and-rerun-ci

@github-actions Bot added the run-ci label on Apr 3, 2026
@Kangyan-Zhou merged commit 5dd2c24 into sgl-project:main on Apr 5, 2026
67 of 102 checks passed
Kangyan-Zhou added a commit that referenced this pull request Apr 5, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Fridge003 added a commit that referenced this pull request Apr 7, 2026
xiezhq-hermann pushed a commit to antgroup/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026