
[RL] adopt local map attention for vLLM attention #2638

Merged — acisseJZhong merged 5 commits into main from rl_attention on Mar 23, 2026
Conversation

@acisseJZhong (Contributor) commented Mar 20, 2026

Summary

  • Adopt LocalMapAttention as the base class for VLLMAttention,
    replacing manual DTensor.to_local() / DTensor.from_local()
    with local_map for DTensor-to-local conversion.
  • Add a __call__ override that captures seq_len = q.size(2) from
    the DTensor before local_map's to_local() and passes it to
    forward() via kwargs. This preserves the canonical symbolic
    shape (s72) that GQAttention uses in its downstream view(bs,
    seqlen, -1). Capturing from the DTensor's global shape ensures
    the correct symbolic size is used under torch.compile; a minimal
    sketch of this override follows the list.
    See pytorch/pytorch#175690 (DTensor Shard->Replicate redistribution corrupts symbolic shapes under torch.compile).
  • Remove replace_with_vllm_compatible_flash_attention() and
    its usage in the trainer — the trainer no longer patches its
    attention module to match vLLM's kernel.
  • Fix the test's vLLM engine creation to use
    GeneratorCompileConfig instead of hardcoded CompilationConfig,
    aligning the test with the generator's actual
    compile/CUDA-graph settings.
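
A minimal sketch of the __call__ override described above (class names follow this summary; the exact torchtitan signatures and module paths may differ):

```python
# Hedged sketch, not the exact torchtitan implementation. Assumes
# LocalMapAttention is importable from the RL experiment module.
class VLLMAttention(LocalMapAttention):
    def __call__(self, q, k, v, *args, **kwargs):
        # q is still a DTensor here, so size(2) is the canonical global
        # symbolic seq_len (s72), not the padded/local size produced
        # inside local_map's to_local().
        kwargs["seq_len"] = q.size(2)
        return super().__call__(q, k, v, *args, **kwargs)
```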

Test attention numerics under both true eager mode and compile mode:

NCCL_NVLS_ENABLE=0 torchrun --nproc_per_node=2 torchtitan/experiments/rl/tests/test_attn_numerics.py

============================================================
LOGPROB COMPARISON RESULTS
============================================================
  Bitwise identical : False
  Tokens checked    : 30
  Tokens different  : 30
  Max delta         : 1.041739e-01
  Avg delta         : 1.816580e-02
  Diff mean         : 2.139650e-03
  Diff max          : 1.041739e-01
============================================================

meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Mar 20, 2026
acisseJZhong marked this pull request as draft on March 20, 2026 at 18:30
output_flat = output_flat.narrow(0, 0, batch_size * seq_len)

# Reshape back to titan: (batch, num_heads_local, seq_len, head_dim)
# Reshape back to titan: (batch, seq_len, num_heads_local, head_dim)
Contributor:

Note the transpose (1, 2) - comment is right since we swap num_heads and seq_len on L273
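
A hedged illustration of this point (variable names are placeholders, not the exact names in the diff):

```python
# After view() the layout is (batch, num_heads_local, seq_len, head_dim);
# transpose(1, 2) then swaps dims 1 and 2, so the second annotation only
# describes the tensor after that swap.
out = output_flat.view(batch, num_heads_local, seq_len, head_dim)
out = out.transpose(1, 2)  # (batch, seq_len, num_heads_local, head_dim)
```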

Contributor Author (acisseJZhong):

Yup, I thought the shape annotation was only for the next line. Reverted it back!

acisseJZhong marked this pull request as ready for review on March 21, 2026 at 00:45
# supporting paged attention / kv cache.
if batch_invariant_mode:
replace_with_vllm_compatible_flash_attention(
replace_with_vllm_attention(
Contributor:

I don't think we need this change to make the trainer and generator bit-wise identical. To get bit-wise identity, the forward path needs to run the same kernels on the trainer and the generator. This function swaps in vllm.Attention() for KV-cache capability. For the trainer, setting the config to varlen attention should be enough.

Also, this might not be trainable, as vllm.Attention should not have a backward?

Let's remove this change from this PR.

Contributor Author (acisseJZhong), Mar 21, 2026:

Oh, I thought vllm.Attention, which uses PyTorchFlashAttentionImpl, has a backward?

I also tried removing this; it doesn't seem to change the numerics:

============================================================
LOGPROB COMPARISON RESULTS
============================================================
  Bitwise identical : False
  Tokens checked    : 30
  Tokens different  : 30
  Max delta         : 1.041739e-01
  Avg delta         : 1.816580e-02
  Diff mean         : 2.139650e-03
  Diff max          : 1.041739e-01
============================================================

Contributor:

> Oh, I thought vllm.Attention, which uses PyTorchFlashAttentionImpl, has a backward?

I see, you are right. But in the trainer we don't need the KV-cache capability, so we can directly use varlen attention to achieve bit-wise identity.

tensor_parallel_size=gen_config.parallelism.tensor_parallel_degree,
distributed_executor_backend="external_launcher",
gpu_memory_utilization=gen_config.gpu_memory_limit,
enforce_eager=gen_config.compile.is_eager,
Contributor:

Why add this? Do you need enforce_eager = True when compile and cudagraph are disabled?

Contributor Author (acisseJZhong):

It should use the config from _test_config below; otherwise there are two compile configs floating around.

acisseJZhong force-pushed the rl_attention branch 2 times, most recently from 391bdda to da24e6c, on March 21, 2026 at 07:08
Returns:
``(batch, num_heads, seq_len, head_dim)``
"""
# Capture the original symbolic seq_len from the input BEFORE
Contributor:

Actually, I don't know why we were using the global seq_len here.

  • In TP, q/k/v are sharded on the num_heads dimension, so seq_len should be the same on the DTensor and the local tensor.
  • In all-gather CP, q/k/v are sharded on the seq_len dimension before entering CP, but k/v are replicated after the hooks (it may not work with varlen attention yet).

In either case, q = q.transpose(1, 2).reshape(batch_size * seq_len, -1, head_dim) should work with the local q/k/v's seq_len? See the sketch below.
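
A minimal sketch of this point, assuming TP shards q on the num_heads dimension (shapes are illustrative, not from the actual model):

```python
import torch

# With TP sharding on num_heads, the local seq_len equals the global one,
# so flattening with the local shape produces the intended layout.
batch_size, num_heads_local, seq_len, head_dim = 2, 4, 16, 64
q_local = torch.randn(batch_size, num_heads_local, seq_len, head_dim)

# (batch, heads, seq, dim) -> (batch, seq, heads, dim) -> (batch*seq, heads, dim)
q_flat = q_local.transpose(1, 2).reshape(batch_size * seq_len, -1, head_dim)
assert q_flat.shape == (batch_size * seq_len, num_heads_local, head_dim)
```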

Contributor Author (acisseJZhong):

It's because of L270:

    # vLLM's flash attention backend may pad the token count (e.g.
    # round up to an even number), which introduces a new symbolic
    # shape under torch.compile.  Narrow to trim this padding
    # NOTE: this error only happens when batch_size and seq_len are 1
    # which happens with cudagraph capture for dummy input

During cudagraph capture with TP=2, the seq_len should be 1, but it's padded to 2, so we need to capture the original symbolic seq_len from the input BEFORE to_local(). After using local_map, this entire forward is wrapped in the local region, so I moved this chunk of code before calling forward().
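
A hedged illustration, reusing the variable names from the snippet above (a sketch of the idea, not the exact code):

```python
# Inside the local_map region, the padded token count is a fresh symbol
# under torch.compile; batch_size * seq_len uses the canonical seq_len
# (s72) captured from the DTensor before to_local(), so narrow() trims
# the backend padding without introducing a new symbolic shape.
output_flat = output_flat.narrow(0, 0, batch_size * seq_len)
```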

Contributor Author (acisseJZhong), Mar 23, 2026:

This needs follow-up with pytorch/pytorch#175690; @Lucaskabela mentioned:

> I think this is a symbolic shape propagation error - w.r.t to_local(), the symbol we are using has to be divisible by the TP size, so I think it is something like ( 2*(s77 + 1) // 2) or something odd like that - this actually results in some bug I believe in our code that is generated

The approach in this PR is just a workaround, like before. It can be deleted once the fix lands.

@acisseJZhong (Contributor Author) commented Mar 23, 2026

Running test_numerics.py from https://github.com/pytorch/torchtitan/pull/2300/changes#diff-61d04a6e7103debe722a05fcf985c39b2471c332de71b93de4e23e878b9a5d5a, all three tests passed. The command I am using is

NCCL_NVLS_ENABLE=0 MASTER_ADDR=localhost MASTER_PORT=29500 MODEL_CHECKPOINT_PATH=Qwen/Qwen3-0.6B pytest torchtitan/experiments/rl/tests/test_numerics.py -v -s

cc @wwwjn

This means the whole attention module (not only self_attn) stays aligned with vLLM's numerics.

acisseJZhong requested a review from zhxchen17 on March 23, 2026 at 17:55
acisseJZhong force-pushed the rl_attention branch 2 times, most recently from fa78caf to d976e7c, on March 23, 2026 at 20:56
@Lucaskabela (Contributor) left a comment:

PR LGTM but may need a rebase to avoid merge conflicts

compute_token_log_probs,
verify_logprob_identity,
)
from torchtitan.experiments.rl.types import Episode
Contributor:

I think this is removed on main if you rebase :)

# Therefore it is breaking compile. We need to fix this in pytorch.
# See more details in https://github.com/pytorch/pytorch/issues/175690
# TODO(@Lucaskabela): remove this once the issue is fixed in pytorch
batch_size, _, seq_len, head_dim = q.shape
Contributor Author (acisseJZhong), Mar 23, 2026:

@Lucaskabela I updated this to still capture seq_len in the local region. This captured seq_len is wrong and will break compile behavior, but it is the right thing to do to get seq_len in the local region instead of the global one.

Please help take a look at the fix for compile. Thanks!!

acisseJZhong merged commit d229e97 into main on Mar 23, 2026 (18 of 33 checks passed)
Lucaskabela added a commit that referenced this pull request Mar 26, 2026
## Summary

We turned compile for the generator off in #2638 due to a conflict with
DTensor and symbolic propagation.

We fix this in pytorch/pytorch#178210, so re-enable this config (once it
lands in nightly).

## Test
```bash
python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b --hf_assets_path=torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B
```
pytorch-bot Bot pushed a commit that referenced this pull request on Mar 27, 2026
weifengpy pushed a commit to weifengpy/torchtitan that referenced this pull request on Mar 27, 2026
chelsea0x3b pushed a commit to chelsea0x3b/torchtitan that referenced this pull request on Mar 30, 2026
acisseJZhong pushed a commit that referenced this pull request on Mar 31, 2026
TXacs pushed two commits to McmillanTAC/torchtitan that referenced this pull request on Apr 13, 2026
ACharacterInASimulation pushed two commits to ACharacterInASimulation/torchtitan that referenced this pull request on Apr 21, 2026

Labels

ciflow/8gpu, CLA Signed (managed by the Meta Open Source bot)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants