[Compiler Toolkit] JointGraph-based Training Prototype for llama3#1794

Merged
SherlockNoMad merged 19 commits into main from joint_graph_runner on Oct 28, 2025

Conversation

@SherlockNoMad (Contributor) commented Oct 3, 2025

This is an end-to-end prototype that runs llama3 with SimpleFSDP through the export-style aot_autograd workflow.

Setup: shard_dp = 2, tp = 4.

MVP

  • [Done] Start with a SimpleFSDP model, enable TP + FSDP
  • [Done] Apply [aot_export_joint_with_descriptors](pytorch/pytorch#163609) on the parallelized module with DTensor inputs to get the joint graph
  • [Done] Apply min_cut_partitioner to split the joint graph into forward and backward graph modules
  • [Done, needs verification] Apply prefetch/bucketing graph passes on fw_gm and bw_gm to reorder and group the communication collectives
  • [Done] Run the joint graph with `aot_compile_joint_with_descriptors`
  • [Done] Regional Inductor for FlexAttention; needs to run on top of pytorch/pytorch#165202 and pytorch/pytorch#164776

Next Steps

  • Enable CudaGraph
  • Enable SimpleFSDP + EP
  • Showcase user annotation on MoE for dispatch, compute, combine region
  • Enable PP with custom Runner
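To illustrate the "joint graph" idea the pipeline above is built on: forward and backward are traced together into one FX graph, which a partitioner can later split. The snippet below is a minimal stand-in using the public `make_fx` tracer on a toy model, not the actual `aot_export_joint_with_descriptors` toolkit API; the function and tensor names are invented for illustration.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def joint(w, x):
    # Forward: a toy linear "model" standing in for llama3; loss is a scalar.
    loss = (x @ w).sum()
    # Backward traced into the same graph: grad of the loss w.r.t. the parameter.
    (gw,) = torch.autograd.grad(loss, w, create_graph=True)
    return loss, gw

w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4)

# Trace forward + backward together into one FX "joint graph" module.
joint_gm = make_fx(joint)(w, x)

# Replaying the traced module computes the same values as eager mode.
loss_t, gw_t = joint_gm(w, x)
```

A partitioner such as min_cut_partitioner then decides which intermediates to save versus recompute when cutting this joint graph into separate forward and backward modules.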

Issues

  • pytorch/pytorch#164559
  • pytorch/pytorch#164543
  • What's the input order for the aot_export_joint graph? Using model.parameters()'s order as input seems wrong.

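As context for the graph input-order question: `model.parameters()` yields entries in registration (definition) order, which is why relying on it for an exported graph's input list is fragile across module refactors. A quick self-contained check (the `model` here is a made-up toy, not the llama3 module):

```python
import torch.nn as nn

# nn.Module yields parameters in registration order; each nn.Linear
# registers its weight before its bias.
model = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))
param_names = [name for name, _ in model.named_parameters()]
# Registration order: 0.weight, 0.bias, 1.weight, 1.bias
```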
Repro steps:
```shell
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
```

Run with FlexAttention:
```shell
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
```

Sample output:
P1975157784: rank0_autograd_function_0fea2786.py
P1975158481: rank1_autograd_function_28587623.py

@meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Oct 3, 2025
@SherlockNoMad changed the title from "Joint Graph Runner" to "JointGraph-based Training Prototype" on Oct 3, 2025
@SherlockNoMad marked this pull request as ready for review on October 9, 2025, 00:31
@tianyu-l (Contributor) left a comment


Is this for exploration purpose? If so I'd suggest we work in a branch / fork.

@ezyang (Contributor) commented Oct 15, 2025

cc @bobrenjc93 @aorenste this might be a good way to look at compile on one rank, perhaps??

@SherlockNoMad (Author) replied:

> cc @bobrenjc93 @aorenste this might be a good way to look at compile on one rank, perhaps??

Yes, I have examples of how the graphs differ across ranks.
We can investigate how to parameterize them.

P1975157784: rank0_autograd_function_0fea2786.py
P1975158481: rank1_autograd_function_28587623.py

@SherlockNoMad changed the title from "JointGraph-based Training Prototype" to "[Compiler Toolkit] JointGraph-based Training Prototype for llama3" on Oct 27, 2025
**SimpleFSDP + TP + EP**
```shell
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --compile.enable --compile.backend "aot_eager" --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
```
A reviewer (Contributor) commented:

@SherlockNoMad I think we should remove --compile.enable as well as --compile.backend here.
Since we are applying the customized CompiledModule, we don't really want to torch.compile() the model.

@SherlockNoMad (Author) replied:
ok

@SherlockNoMad SherlockNoMad merged commit 06ec495 into main Oct 28, 2025
4 checks passed
jquesnelle pushed a commit to NousResearch/torchtitan that referenced this pull request Nov 10, 2025
Co-authored-by: Simon Fan <xmfan@meta.com>
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026