[RL] adopt local map attention for vLLM attention #2638
Conversation
Force-pushed 29220ee to 540e0b3
| output_flat = output_flat.narrow(0, 0, batch_size * seq_len)
| # Reshape back to titan: (batch, num_heads_local, seq_len, head_dim)
| # Reshape back to titan: (batch, seq_len, num_heads_local, head_dim)
Note the transpose(1, 2): the comment is right, since we swap num_heads and seq_len on L273.
Yup, I thought the shape annotation was only for the next line. Reverted!
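For reference, a minimal sketch (with made-up sizes, not the PR's code) of the layout swap both comments describe:

```python
import torch

# Illustrative only: attention output starts as
# (batch, num_heads_local, seq_len, head_dim).
batch, num_heads_local, seq_len, head_dim = 2, 4, 16, 8
out = torch.randn(batch, num_heads_local, seq_len, head_dim)

# transpose(1, 2) swaps num_heads_local and seq_len, so the annotation on the
# following line describes (batch, seq_len, num_heads_local, head_dim).
out = out.transpose(1, 2)
assert out.shape == (batch, seq_len, num_heads_local, head_dim)
```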
Force-pushed dad2b8a to 7cad1e9
| # supporting paged attention / kv cache.
| if batch_invariant_mode:
| replace_with_vllm_compatible_flash_attention(
| replace_with_vllm_attention(
I don't think we need this change to make the trainer and generator bit-wise identical. For bit-wise identity, we need the forward path to run the same kernels on the trainer and the generator. This function swaps in vllm.Attention() for its kv cache capability. For the trainer, setting the config to varlen attention should be enough.
Also, this might not be trainable, since vllm.Attention probably doesn't have a backward?
Let's remove this change from this PR.
Oh, I thought vllm.Attention, which uses PyTorchFlashAttentionImpl, has the backward?
Also tried removing this; it doesn't seem to change the numerics:
============================================================
LOGPROB COMPARISON RESULTS
============================================================
Bitwise identical : False
Tokens checked : 30
Tokens different : 30
Max delta : 1.041739e-01
Avg delta : 1.816580e-02
Diff mean : 2.139650e-03
Diff max : 1.041739e-01
============================================================
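For context, a rough sketch of how the deltas in that report could be computed from two per-token logprob tensors; the real check lives in verify_logprob_identity in the test, and the function below is only illustrative:

```python
import torch

def compare_logprobs(trainer_lp: torch.Tensor, generator_lp: torch.Tensor) -> None:
    # trainer_lp / generator_lp: hypothetical 1-D per-token logprob tensors
    diff = (trainer_lp - generator_lp).abs()
    print(f"Bitwise identical : {torch.equal(trainer_lp, generator_lp)}")
    print(f"Tokens checked    : {diff.numel()}")
    print(f"Tokens different  : {(diff > 0).sum().item()}")
    print(f"Max delta         : {diff.max().item():.6e}")
    print(f"Avg delta         : {diff.mean().item():.6e}")
```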
> Oh, I thought vllm.Attention, which uses PyTorchFlashAttentionImpl, has the backward?
I see, you are right. But in the trainer we don't need the kv cache capability, so we can directly use varlen attention to achieve bit-wise identity.
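As a hedged sketch of what "varlen attention" refers to here, using flash-attn's flash_attn_varlen_func directly; the actual torchtitan config/wrapper may differ, and this is only to illustrate the kernel being discussed:

```python
import torch
from flash_attn import flash_attn_varlen_func  # assumes the flash-attn package

def varlen_attn(q, k, v, seqlens):
    # q/k/v: (total_tokens, num_heads, head_dim), packed across sequences
    # seqlens: list of per-sequence lengths that sum to total_tokens
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=q.device)
    cu_seqlens[1:] = torch.tensor(seqlens, device=q.device).cumsum(0)
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
        causal=True,
    )
```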
| tensor_parallel_size=gen_config.parallelism.tensor_parallel_degree,
| distributed_executor_backend="external_launcher",
| gpu_memory_utilization=gen_config.gpu_memory_limit,
| enforce_eager=gen_config.compile.is_eager,
Why add this? Do you need enforce_eager=True when compile and cudagraph are disabled?
It should use the config from _test_config below; otherwise there are two compile configs floating around.
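A hedged sketch of what using the shared config could look like; the gen_config field names come from the diff above, and everything else is illustrative rather than the test's exact code:

```python
from vllm import LLM

def make_engine(model_path: str, gen_config) -> LLM:
    # Reuse the generator's config so the test and the generator agree on
    # compile / CUDA-graph settings instead of keeping two configs around.
    return LLM(
        model=model_path,
        tensor_parallel_size=gen_config.parallelism.tensor_parallel_degree,
        distributed_executor_backend="external_launcher",
        gpu_memory_utilization=gen_config.gpu_memory_limit,
        enforce_eager=gen_config.compile.is_eager,
    )
```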
Force-pushed 391bdda to da24e6c
| Returns:
|     ``(batch, num_heads, seq_len, head_dim)``
| """
| # Capture the original symbolic seq_len from the input BEFORE
Actually I don't know why we were using the global seq_len here.
- In TP, q/k/v are sharded on the num_heads dimension, so seq_len should be the same on the DTensor and the local tensor.
- In all-gather CP, q/k/v are sharded on the seq_len dimension before entering CP, but k/v are replicated after the hooks (it may not work with varlen attention yet).

In either case, q = q.transpose(1, 2).reshape(batch_size * seq_len, -1, head_dim) should work with the local qkv's seq_len?
It's because of L270:

> # vLLM's flash attention backend may pad the token count (e.g. round up to an even number), which introduces a new symbolic shape under torch.compile. Narrow to trim this padding.
> # NOTE: this error only happens when batch_size and seq_len are 1, which happens with cudagraph capture for dummy input.

During cudagraph capture with TP=2, the seq_len should be 1, but it's padded to 2, and we need to capture the original symbolic seq_len from the input BEFORE to_local. After using local_map, this entire forward is wrapped in the local region, so I moved this chunk of code before calling forward.
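A standalone sketch (made-up sizes, not the PR's exact code) of the pad-then-trim behavior being described:

```python
import torch

batch_size, seq_len, hidden = 1, 1, 64   # the cudagraph-capture dummy-input case
tokens = batch_size * seq_len

# Pretend the vLLM flash-attention backend rounded 1 token up to 2.
output_padded = torch.randn(2, hidden)

# Narrow back down to the real token count to drop the padding row.
output_flat = output_padded.narrow(0, 0, tokens)
assert output_flat.shape == (tokens, hidden)
```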
This needs follow-up with pytorch/pytorch#175690; @Lucaskabela mentioned:

> I think this is a symbolic shape propagation error - w.r.t. to_local(), the symbol we are using has to be divisible by the TP size, so I think it is something like (2*(s77 + 1) // 2) or something odd like that - this actually results in some bug, I believe, in the code that is generated

The approach in this PR is just a workaround, like before. It can be deleted once the fix lands.
Running the test_numerics.py from https://github.com/pytorch/torchtitan/pull/2300/changes#diff-61d04a6e7103debe722a05fcf985c39b2471c332de71b93de4e23e878b9a5d5a, all three tests passed. The command I am using is ... cc @wwwjn

It means the whole attention module (not only ...)
Force-pushed fa78caf to d976e7c
Lucaskabela left a comment:
PR LGTM but may need a rebase to avoid merge conflicts
| compute_token_log_probs,
| verify_logprob_identity,
| )
| from torchtitan.experiments.rl.types import Episode
I think this is removed on main if you rebase :)
Force-pushed 6b373fd to d36ea94
| # Therefore it is breaking compile. We need to fix this in pytorch.
| # See more details in https://github.com/pytorch/pytorch/issues/175690
| # TODO(@Lucaskabela): remove this once the issue is fixed in pytorch
| batch_size, _, seq_len, head_dim = q.shape
@Lucaskabela I updated this to still capture seq_len in the local region. This captured seq_len is wrong and will break compile behavior, but it is the right thing to do to get seq_len in the local region instead of the global one.
Please help take a look at the fix for compile. Thanks!
Force-pushed e785d79 to 512e404
Force-pushed 512e404 to ba38d7f
## Summary

We turned compile for the generator off in #2638 due to a conflict with DTensor and symbolic shape propagation. We fix this in pytorch/pytorch#178210, so re-enable this config (once it lands in nightly).

## Test

```bash
python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b --hf_assets_path=torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B
```
**Summary**

- Adopt LocalMapAttention as the base class for VLLMAttention, replacing manual DTensor.to_local() / DTensor.from_local() with local_map for DTensor-to-local conversion.
- Add a __call__ override that captures seq_len = q.size(2) from the DTensor before local_map's to_local() and passes it to forward() via kwargs. This preserves the canonical symbolic shape (s72) that GQAttention uses in its downstream view(bs, seqlen, -1). Capturing from the DTensor's global shape ensures the correct symbolic size is used under torch.compile. See pytorch/pytorch#175690 (DTensor Shard->Replicate redistribution corrupts symbolic shapes under torch.compile).
- Remove replace_with_vllm_compatible_flash_attention() and its usage in the trainer; the trainer no longer patches its attention module to match vLLM's kernel.
- Fix the test's vLLM engine creation to use GeneratorCompileConfig instead of a hardcoded CompilationConfig, aligning the test with the generator's actual compile/CUDA-graph settings.

**Test attention numerics under both true eager mode and compile mode:**

```
NCCL_NVLS_ENABLE=0 torchrun --nproc_per_node=2 torchtitan/experiments/rl/tests/test_attn_numerics.py
```

> ============================================================
> LOGPROB COMPARISON RESULTS
> ============================================================
> Bitwise identical : False
> Tokens checked : 30
> Tokens different : 30
> Max delta : 1.041739e-01
> Avg delta : 1.816580e-02
> Diff mean : 2.139650e-03
> Diff max : 1.041739e-01
> ============================================================
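A hedged sketch of the __call__ override described in the summary; the class bodies are stand-ins, not the exact torchtitan code:

```python
import torch.nn as nn

class LocalMapAttention(nn.Module):
    """Stand-in for the base class whose forward() is wrapped in local_map."""

class VLLMAttention(LocalMapAttention):
    def __call__(self, q, k, v, *args, **kwargs):
        # q is still a DTensor here, so q.size(2) is the global symbolic
        # seq_len (s72) that GQAttention later uses in view(bs, seqlen, -1).
        kwargs["seq_len"] = q.size(2)
        # nn.Module.__call__ then dispatches to the local_map-wrapped forward,
        # which only sees local tensors plus the captured seq_len.
        return super().__call__(q, k, v, *args, **kwargs)
```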