
[RL][BE] Reapply Compilation in RL #2718

Merged
Lucaskabela merged 1 commit into pytorch:main from Lucaskabela:lucaskabela/use_cudagraph_enum
Mar 27, 2026

Conversation

@Lucaskabela (Contributor) commented on Mar 26, 2026

Summary

This PR:

  1. Reapplies "[RL] Turn compile for generator back on" (#2710)

Test plan

PREREQ: ensure pytorch/pytorch#178210 is in your torch version

python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b --hf_assets_path=torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B

With both compiles on, we expect a 4x speedup over eager (timed end-to-end: ~400s down to ~100s for 10 steps).

torchrun --nproc_per_node=2 \
      torchtitan/experiments/rl/tests/test_bitwise_identity.py

For numerics, this results in:

Trainer computed 30 token log-probs
  vLLM   log-probs[:5]: [-4.129570484161377, -1.795021891593933, -0.71578049659729, -0.2110116183757782, -0.9374725222587585]
  Trainer log-probs[:5]: [-4.129570484161377, -1.795021891593933, -0.71578049659729, -0.2110116183757782, -0.9374725222587585]
============================================================
LOGPROB COMPARISON RESULTS
============================================================
  Bitwise identical : True
  Tokens checked    : 30
  Tokens different  : 0
  Max delta         : 0.000000e+00
  Avg delta         : 0.000000e+00
  Diff mean         : 0.000000e+00
  Diff max          : 0.000000e+00
============================================================
PASS: vLLM and trainer log-probs are bitwise identical.
/home/lucaskabela/.conda/envs/pytorch_build/lib/python3.10

@meta-cla bot added the CLA Signed label on Mar 26, 2026
@Lucaskabela force-pushed the lucaskabela/use_cudagraph_enum branch from e0ec316 to 8984104 on March 26, 2026 at 17:57
@Lucaskabela changed the title from "[BE] Change from string to enum for cudagraph options in vLLM Config" to "[RL][BE] Reapply Compilation in RL and migrate string to enum for Cudagraph" on Mar 26, 2026
@Lucaskabela force-pushed the lucaskabela/use_cudagraph_enum branch from 8984104 to c2c1eab on March 26, 2026 at 18:38
@Lucaskabela requested a review from wwwjn on March 26, 2026 at 18:40
```diff
):
    raise ValueError(
-        f"cudagraph_mode='{self.cudagraph_mode}' requires piecewise graph "
+        f"cudagraph_mode={self.cudagraph_mode!r} requires piecewise graph "
```
Contributor:

typo?

Contributor Author:

Err, I don't believe so; `!r` gets the repr of the enum, right?
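For reference, a minimal sketch of the difference (using a stand-in enum, since the real `CUDAGraphMode` lives in vLLM):

```python
import enum

# Stand-in for vLLM's CUDAGraphMode, only to illustrate !r formatting.
class CUDAGraphMode(enum.Enum):
    NONE = "none"
    PIECEWISE = "piecewise"

mode = CUDAGraphMode.PIECEWISE
print(f"cudagraph_mode={mode}")    # cudagraph_mode=CUDAGraphMode.PIECEWISE
print(f"cudagraph_mode={mode!r}")  # cudagraph_mode=<CUDAGraphMode.PIECEWISE: 'piecewise'>
```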

```diff
        # ),
-        compile=GeneratorCompileConfig(backend="none", cudagraph_mode="none"),
+        compile=GeneratorCompileConfig(
+            backend="eager", cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE
```
Contributor:

Curious: how do you decide between PIECEWISE and FULL_AND_PIECEWISE?

Contributor Author:

The difference comes down to full-cudagraph capturability (whether the attention implementation is custom or not) and memory usage; a full graph uses more memory, so if we don't see much speedup it's probably better to default to piecewise. Let me make this consistent here.
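To make that trade-off concrete, a hedged sketch of the selection logic described above (the helper and its arguments are illustrative, not part of this PR, and the import path may differ across vLLM versions):

```python
from vllm.config import CUDAGraphMode  # import path assumed; varies by vLLM version

def pick_cudagraph_mode(attention_fully_capturable: bool,
                        memory_headroom: bool) -> CUDAGraphMode:
    # FULL_AND_PIECEWISE additionally captures one full graph: this needs full
    # capturability (custom attention implementations may break it) and more memory.
    if attention_fully_capturable and memory_headroom:
        return CUDAGraphMode.FULL_AND_PIECEWISE
    # PIECEWISE splits graphs around non-capturable ops and supports dynamic
    # sizes at lower memory cost -- the safer default.
    return CUDAGraphMode.PIECEWISE
```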

```python
NOTE: inductor will offer the best performance, but will impact numerics - use eager for
bitwise identical results."""

cudagraph_mode: Literal[
```
Contributor:

FWIW, torchtitan's style is to use Literal rather than Enum, so I'd prefer we consistently work with Literal if possible.

Does tyro work with Enum?

Contributor Author:

I can change this back then... I thought the enum would be preferred since it's already what vLLM uses.

Contributor:

If this is only used for the generator, that's fine, but the general rule should be: whenever torchtitan and vLLM have a style conflict, bias toward torchtitan.

If we are going to enable these modes for the trainer as well, then I strongly think we should consistently use Literal.

Contributor Author:

Hmm, since tyro works with Enum, could we actually keep the enum here? I think this may be better for interoperability: if something changes in vLLM, it is downstreamed here for free.

Contributor Author:

The trainer callout is a good consideration; let's revert back to Literal then.

@Lucaskabela (Contributor Author):

@claude does tyro work with enums here?

@claude (bot) commented on Mar 26, 2026:

Claude finished @Lucaskabela's task in 6m 1s.


Does tyro work with enums?


Yes, tyro supports Python enum.Enum types natively. When a dataclass field has an Enum type annotation, tyro automatically generates CLI options using the enum member names (case-insensitive). For example, with cudagraph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE, the CLI would accept --generator.compile.cudagraph-mode PIECEWISE (or piecewise, none, full, etc.).

Tyro has had enum support for a long time — it's covered in tyro's enum examples. Since torchtitan requires tyro >= 1.0.5, this is well within the supported range.

However, note tianyu-l's review comment: torchtitan's convention is to use Literal types rather than Enum for config fields. The existing pattern in torchtitan/config/configs.py consistently uses Literal (e.g., Literal["bfloat16", "float32"], Literal["selective", "full", "memory_budget", "none"]). Even when an Enum exists internally (like AsyncMode in torchtitan/components/checkpoint.py), the config field uses Literal and converts to the enum at runtime.

So while tyro can handle the CUDAGraphMode enum, the question is whether it should be used directly in the config dataclass given torchtitan's style preference for Literal. The trade-off:

  • Literal approach: Consistent with the rest of torchtitan's config system, no external type dependency in the config dataclass, straightforward string values on the CLI.
  • Enum approach: Stronger type safety, matches vLLM's API directly (no string conversion needed), but introduces a vLLM import into the dataclass definition.

Since this is in torchtitan/experiments/rl/ which already has a hard dependency on vLLM, the external import concern is less of an issue. But stylistic consistency with the rest of the codebase is still worth considering per tianyu-l's feedback.
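As a rough illustration of the two options, here is a self-contained sketch with a stand-in enum (field names are illustrative; the Literal-to-enum conversion mirrors the AsyncMode pattern described above):

```python
import enum
from dataclasses import dataclass
from typing import Literal

import tyro


class CUDAGraphMode(enum.Enum):
    # Stand-in for vLLM's enum, so this sketch is self-contained.
    NONE = "none"
    PIECEWISE = "piecewise"
    FULL_AND_PIECEWISE = "full_and_piecewise"


@dataclass
class EnumStyleConfig:
    # tyro renders this as a CLI choice over the enum's member names.
    cudagraph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE


@dataclass
class LiteralStyleConfig:
    # torchtitan's convention: plain strings in the config dataclass...
    cudagraph_mode: Literal["none", "piecewise", "full_and_piecewise"] = "piecewise"

    def to_vllm(self) -> CUDAGraphMode:
        # ...converted to the enum at the vLLM boundary, mirroring the
        # Literal-then-convert pattern used for AsyncMode.
        return CUDAGraphMode(self.cudagraph_mode)


if __name__ == "__main__":
    cfg = tyro.cli(LiteralStyleConfig)
    print(cfg.to_vllm())
```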

@Lucaskabela requested a review from tianyu-l on March 26, 2026 at 21:31
@Lucaskabela force-pushed the lucaskabela/use_cudagraph_enum branch from c2c1eab to 173580c on March 26, 2026 at 21:34

Comment on lines 47 to 50:

```
Piecewise capture supports dynamic sizes and splits cudagraphs around non capturable
ops like attention
Full capture captures one graph at the expense of less dynamism and requires full
capturability
```

Contributor:

should make these consistent
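One possible parallel phrasing (a suggestion only, not the wording that landed in the PR):

```
Piecewise capture: supports dynamic sizes; splits cudagraphs around
non-capturable ops like attention.
Full capture: captures one graph; requires full capturability and trades
away some dynamism.
```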

@Lucaskabela force-pushed the lucaskabela/use_cudagraph_enum branch from 173580c to 24bd0af on March 26, 2026 at 21:53
@Lucaskabela changed the title from "[RL][BE] Reapply Compilation in RL and migrate string to enum for Cudagraph" to "[RL][BE] Reapply Compilation in RL" on Mar 26, 2026
@acisseJZhong (Contributor) left a comment:

lgtm!

```diff
        tensor_parallel_degree=2,
    ),
-    # compile=CompileConfig(enable=True, backend="aot_eager"),
+    compile=CompileConfig(enable=True, backend="aot_eager"),
```
Contributor:

Nit: why does the trainer side have an `enable` field in CompileConfig while the generator side doesn't? Seems not very symmetrical.
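For context, a sketch of the asymmetry being flagged (field sets are inferred from the snippets quoted in this review, not the full definitions):

```python
from dataclasses import dataclass
from typing import Literal

@dataclass
class CompileConfig:
    # Trainer side, per the diff above: an explicit enable flag.
    enable: bool = False
    backend: str = "aot_eager"

@dataclass
class GeneratorCompileConfig:
    # Generator side: no enable flag; disabling is expressed through the
    # backend/cudagraph_mode values instead.
    backend: str = "eager"
    cudagraph_mode: Literal["none", "piecewise", "full_and_piecewise"] = "piecewise"
```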

@Lucaskabela force-pushed the lucaskabela/use_cudagraph_enum branch from 24bd0af to 33f00f6 on March 27, 2026 at 12:35
@Lucaskabela merged commit cd9bee1 into pytorch:main on Mar 27, 2026
13 of 22 checks passed
chelsea0x3b pushed a commit to chelsea0x3b/torchtitan that referenced this pull request Mar 30, 2026
acisseJZhong pushed a commit that referenced this pull request Mar 31, 2026
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
@Lucaskabela Lucaskabela deleted the lucaskabela/use_cudagraph_enum branch May 6, 2026 19:29

Labels

ciflow/8gpu, CLA Signed