
[annotate] Annotation should be mapped across submod#165202

Closed
yushangdi wants to merge 2 commits into main from annotation_submod

Conversation

@yushangdi
Contributor

@yushangdi yushangdi commented Oct 10, 2025

The match for backward nodes might be in a different submod, so we should check all submods for potential matches.
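
The cross-submod search can be sketched as follows. This is a minimal illustration with hypothetical `Node`/`SubModule` stand-ins, not PyTorch's actual FX types; the point is only that the lookup iterates over every submodule's nodes instead of just the backward node's own submodule:

```python
# Hedged sketch (not PyTorch's real implementation): match a backward node
# to its forward counterpart by seq_nr, searching ALL submodules' graphs.
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    seq_nr: int
    meta: dict = field(default_factory=dict)


@dataclass
class SubModule:
    name: str
    nodes: list


def find_forward_match(bw_node, submodules):
    """Return the first forward node with a matching seq_nr, from any submod."""
    for submod in submodules:          # search across all submodules,
        for fw_node in submod.nodes:   # not only the one containing bw_node
            if fw_node.seq_nr == bw_node.seq_nr:
                return fw_node
    return None
```

With this shape, a backward flex_attention node whose forward counterpart landed in a different submod (because extra ops shifted it there) still finds its match and can inherit its annotation.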

In flex attention, this could happen if `mask_mod` has operations (such as index) that increase the seq_nr of the forward graph nodes. Then the backward flex_attention nodes cannot find a match in their own subgraph.
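
For illustration, here is a hypothetical `mask_mod` written as a plain-Python stand-in for the flex attention signature `mask_mod(b, h, q_idx, kv_idx)`; the `doc_ids` lookup is the kind of index operation that adds extra nodes to the forward graph (and thus bumps seq_nr):

```python
# Hypothetical document-causal mask_mod. The doc_ids[...] lookups are index
# ops: in a traced graph they become extra forward nodes, shifting seq_nr.
doc_ids = [0, 0, 1, 1, 2]  # which document each token position belongs to


def document_causal_mask(b, h, q_idx, kv_idx):
    same_doc = doc_ids[q_idx] == doc_ids[kv_idx]  # index op
    causal = q_idx >= kv_idx
    return same_doc and causal
```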

```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```

Also tested on torchtitan joint_graph_runner branch. The flex_attention backward nodes are annotated now.

```
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" LOG_RANK=0 TRAIN_FILE="torchtitan.train" TORCHFT_LIGHTHOUSE="http://localhost:29510" PYTORCH_ALLOC_CONF="expandable_segments:True" torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint="localhost:0" --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml --model.name joint_graph_runner.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
```

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

@pytorch-bot

pytorch-bot bot commented Oct 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165202

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 676774b with merge base 37d57ac:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yushangdi yushangdi added the release notes: fx release notes category label Oct 11, 2025
@yushangdi yushangdi changed the title annotation should be mapped across submod [annotate] Annotation should be mapped across submod Oct 11, 2025
@yushangdi yushangdi marked this pull request as ready for review October 11, 2025 01:23
@yushangdi yushangdi requested a review from bdhirsh as a code owner October 11, 2025 01:23
Contributor

@SherlockNoMad SherlockNoMad left a comment


LGTM with minor comments.
`get_custom_metadata` doesn't seem useful to users; let's make it private.

@yushangdi yushangdi force-pushed the annotation_submod branch 2 times, most recently from a0c8704 to 431e6b8 Compare October 13, 2025 16:53
@yushangdi
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 13, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch rebase origin/main returned non-zero exit code 1

```
Rebasing (1/1)
Auto-merging test/dynamo/test_fx_annotate.py
CONFLICT (content): Merge conflict in test/dynamo/test_fx_annotate.py
error: could not apply d018c77aadd... [annotate] Annotation should be mapped across submod (#165202)
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply d018c77aadd... [annotate] Annotation should be mapped across submod (#165202)
```
Details for Dev Infra team Raised by workflow job

@yushangdi
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 15, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
SherlockNoMad added a commit to pytorch/torchtitan that referenced this pull request Oct 28, 2025
…torch#1794)

This is an e2e prototype to run llama3-simplefsdp using export-y
aot_autograd workflow.

Setup: shard_dp = 2, tp = 4.

MVP
- [Done] Start with a simpleFSDP model, enable TP + FSDP 
- [Done] Apply
[aot_export_joint_with_descriptors](pytorch/pytorch#163609)
on parallelized module with DTensor input to get the joint graph
- [Done] Apply min_cut_partitioner to get forward and backward graph
module
- [Done but Need verification] Apply prefetch/bucketing graph passes on
fw_gm and bw_gm to reorder/group the communication collectives
- [Done] Run the joint graph with `aot_compile_joint_with_descriptors`
- [Done] Region Inductor for FlexAttention, need to run on top of
pytorch/pytorch#165202 and
pytorch/pytorch#164776

Next Steps
- Enable CudaGraph
- Enable SimpleFSDP + EP 
- Showcase user annotation on MoE for dispatch, compute, combine region
- Enable PP with custom Runner


Issues
- pytorch/pytorch#164559
- pytorch/pytorch#164543
- What's the input order for the aot_export_joint graph? Using the order
of model.parameters() as the input order seems wrong.


Repro steps:
```
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
```

Run with FlexAttention:
```
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
```

Sample output:
P1975157784: rank0_autograd_function_0fea2786.py
P1975158481: rank1_autograd_function_28587623.py

---------

Co-authored-by: Simon Fan <xmfan@meta.com>
jquesnelle pushed a commit to NousResearch/torchtitan that referenced this pull request Nov 10, 2025
…torch#1794)

@github-actions github-actions bot deleted the annotation_submod branch November 14, 2025 02:17
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
…torch#1794)

xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
…torch#1794)


Labels

ciflow/inductor · ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · module: dynamo · release notes: fx (release notes category)
