[RL][BE] Reapply Compilation in RL #2718
Conversation
```diff
             ):
                 raise ValueError(
-                    f"cudagraph_mode='{self.cudagraph_mode}' requires piecewise graph "
+                    f"cudagraph_mode={self.cudagraph_mode!r} requires piecewise graph "
```
Err, I don't believe so - `!r` gets the repr of the enum, right?
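For reference, `!r` in an f-string calls `repr()` on the value, so a plain `Enum` member formats with its class, name, and value. A standalone sketch (this `CUDAGraphMode` stand-in is hypothetical, not vLLM's actual definition):

```python
from enum import Enum

# Hypothetical stand-in for the real CUDAGraphMode enum
class CUDAGraphMode(Enum):
    PIECEWISE = "piecewise"

mode = CUDAGraphMode.PIECEWISE
# {mode!r} uses repr(), which includes the class, member name, and value
print(f"cudagraph_mode={mode!r}")    # cudagraph_mode=<CUDAGraphMode.PIECEWISE: 'piecewise'>
# '{mode}' uses str(), which for a plain Enum is Class.MEMBER
print(f"cudagraph_mode='{mode}'")    # cudagraph_mode='CUDAGraphMode.PIECEWISE'
```

So swapping `'{self.cudagraph_mode}'` for `{self.cudagraph_mode!r}` changes which of these two forms appears in the error message.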
```diff
         # ),
-        compile=GeneratorCompileConfig(backend="none", cudagraph_mode="none"),
+        compile=GeneratorCompileConfig(
+            backend="eager", cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE
```
Curious - how do you decide between PIECEWISE and FULL_AND_PIECEWISE?
The difference comes down to full-cudagraph capturability (whether a custom attention implementation is involved or not) and memory usage. A full graph will use more memory, so if we don't see much speedup it's probably better to default to piecewise. Let me make this consistent here.
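The tradeoff described above could be encoded as a config default along these lines. A minimal sketch, assuming a simplified `GeneratorCompileConfig` whose field names are inferred from the diff, not the actual class:

```python
from dataclasses import dataclass
from typing import Literal

@dataclass
class GeneratorCompileConfig:
    # Hypothetical simplified mirror of the config discussed in this thread.
    backend: Literal["none", "eager", "inductor"] = "eager"
    # "piecewise" splits cudagraphs around non-capturable ops (e.g. custom
    # attention) and uses less memory; "full_and_piecewise" additionally
    # captures one full graph, trading extra memory for a possible speedup.
    cudagraph_mode: Literal["none", "piecewise", "full_and_piecewise"] = "piecewise"

cfg = GeneratorCompileConfig()
print(cfg.cudagraph_mode)  # piecewise
```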
```diff
     NOTE: inductor will offer the best performance, but will impact numerics - use eager for
     bitwise identical results."""


+    cudagraph_mode: Literal[
```
Fwiw, torchtitan's style is to use Literal rather than Enum, so I'd prefer we consistently work with Literal if possible.
Does tyro work with Enum?
I can change this back then... I thought Enum would be preferred since it's already what vLLM uses.
If this is only used for the generator, that's fine, but the general rule should be "whenever torchtitan and vLLM have a style conflict, bias toward torchtitan".
If we are going to enable these modes for the trainer as well, then I strongly think we should consistently use Literal.
Hmm, since tyro works with Enum, could we actually keep the enum here? I think this may be better for interoperability, since any change in vLLM is downstreamed here for free.
The trainer callout is a good consideration - let's revert to Literal then.
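To illustrate the two styles this thread weighs: both declarations below express the same set of allowed values. A sketch with illustrative field names and values (assumptions, not the actual torchtitan or vLLM definitions), where `Literal` matches torchtitan's style and `Enum` matches vLLM's:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Literal, get_args

# vLLM-style: an Enum type (member values here are illustrative)
class CUDAGraphMode(Enum):
    NONE = "none"
    PIECEWISE = "piecewise"
    FULL_AND_PIECEWISE = "full_and_piecewise"

@dataclass
class EnumStyleConfig:
    cudagraph_mode: CUDAGraphMode = CUDAGraphMode.NONE

# torchtitan-style: a Literal of plain strings
CudagraphModeLiteral = Literal["none", "piecewise", "full_and_piecewise"]

@dataclass
class LiteralStyleConfig:
    cudagraph_mode: CudagraphModeLiteral = "none"

# Both spell out the same set of choices:
enum_values = {m.value for m in CUDAGraphMode}
literal_values = set(get_args(CudagraphModeLiteral))
print(enum_values == literal_values)  # True
```

The practical difference is where the vocabulary lives: the Enum tracks vLLM's type directly, while the Literal keeps the CLI surface as plain strings in torchtitan's style.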
@claude does tyro work with enums here?
Claude finished @Lucaskabela's task in 6m 1s
```diff
+    Piecewise capture supports dynamic sizes and splits cudagraphs around non capturable
+    ops like attention
+    Full capture captures one graph at the expense of less dynamism and requires full
+    capturability
```
should make these consistent
```diff
             tensor_parallel_degree=2,
         ),
-        # compile=CompileConfig(enable=True, backend="aot_eager"),
+        compile=CompileConfig(enable=True, backend="aot_eager"),
```
Nit: why does the trainer side have an `enable` field in `CompileConfig` while the generator side doesn't? Seems asymmetrical.
## Summary

This PR reapplies #2710.

## Test plan

PREREQ: ensure pytorch/pytorch#178210 is in your torch version.

```bash
python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b --hf_assets_path=torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B
```

With both compiles on, we expect a 4x speedup over eager (timed from ~400s e2e to ~100s for 10 steps).

```bash
torchrun --nproc_per_node=2 \
    torchtitan/experiments/rl/tests/test_bitwise_identity.py
```

For numerics, this results in:

```
Trainer computed 30 token log-probs
vLLM log-probs[:5]:    [-4.129570484161377, -1.795021891593933, -0.71578049659729, -0.2110116183757782, -0.9374725222587585]
Trainer log-probs[:5]: [-4.129570484161377, -1.795021891593933, -0.71578049659729, -0.2110116183757782, -0.9374725222587585]
============================================================
LOGPROB COMPARISON RESULTS
============================================================
Bitwise identical : True
Tokens checked    : 30
Tokens different  : 0
Max delta         : 0.000000e+00
Avg delta         : 0.000000e+00
Diff mean         : 0.000000e+00
Diff max          : 0.000000e+00
============================================================
PASS: vLLM and trainer log-probs are bitwise identical.
/home/lucaskabela/.conda/envs/pytorch_build/lib/python3.10
```
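The bitwise check reported above can be sketched as a simple element-wise comparison over the two log-prob lists. A hypothetical helper, not the actual `test_bitwise_identity.py` code:

```python
def compare_logprobs(vllm_lp, trainer_lp):
    """Compare two equal-length lists of token log-probs for bitwise identity."""
    assert len(vllm_lp) == len(trainer_lp)
    deltas = [abs(a - b) for a, b in zip(vllm_lp, trainer_lp)]
    return {
        "bitwise_identical": all(d == 0.0 for d in deltas),
        "tokens_checked": len(deltas),
        "tokens_different": sum(1 for d in deltas if d > 0.0),
        "max_delta": max(deltas),
    }

# Identical inputs (as in the PASS run above) give zero deltas everywhere.
lp = [-4.129570484161377, -1.795021891593933, -0.71578049659729]
result = compare_logprobs(lp, list(lp))
print(result["bitwise_identical"], result["max_delta"])  # True 0.0
```

Exact float equality (delta of exactly 0.0) is the right bar here, since the whole point of the eager backend is bitwise-identical results between vLLM and the trainer.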