[DTensor][Export] Supporting exporting a model with DTensor params/inputs#163609

Closed

SherlockNoMad wants to merge 15 commits into main from bahuang/export_dtensor
Conversation

@SherlockNoMad (Contributor) commented Sep 23, 2025

I experimented with three paths to get a joint graph for a DTensorized module and inputs:

  1. strict_export + aot_export_joint_with_descriptors
  2. graph_capture + aot_export_joint_with_descriptors
  3. aot_export_joint_with_descriptors alone

Added test to guard them.

Path 1 doesn't work, as the backward graph region is missing from the joint graph.
I am leaning towards making path 2 the recommended path.
If path 2 stops working going forward, we can fall back to path 3.
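The "prefer path 2, fall back to path 3" strategy can be sketched as a simple fallback chain. This is illustrative only: `capture_joint_graph` and the toy strategy functions below are hypothetical stand-ins, not real torch APIs; only the ordering of preferences comes from this PR.

```python
# Illustrative sketch of trying capture paths in preference order.
# All names here are hypothetical stand-ins, not real torch APIs.

def capture_joint_graph(model, inputs, strategies):
    """Try each capture strategy in order; return the first joint graph."""
    errors = []
    for name, strategy in strategies:
        try:
            return name, strategy(model, inputs)
        except Exception as exc:  # a real implementation would narrow this
            errors.append((name, exc))
    raise RuntimeError(f"all capture strategies failed: {errors}")


# Toy strategies standing in for paths 2 and 3.
def graph_capture_then_joint(model, inputs):
    raise NotImplementedError("pretend path 2 regressed")

def joint_alone(model, inputs):
    return "joint_graph"

name, graph = capture_joint_graph(
    None,
    None,
    [
        ("graph_capture + aot_export_joint_with_descriptors", graph_capture_then_joint),
        ("aot_export_joint_with_descriptors alone", joint_alone),
    ],
)
print(name, graph)  # falls back to the second strategy
```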

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

@pytorch-bot bot commented Sep 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163609

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 949ea83 with merge base 8701f18:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue ciflow/inductor release notes: export labels Sep 23, 2025
@ezyang (Contributor) commented Sep 24, 2025

The interactions with subsystems I'm not familiar with are a question mark to me, but the parts I do understand look fine.

@SherlockNoMad SherlockNoMad marked this pull request as ready for review September 27, 2025 00:06
@SherlockNoMad SherlockNoMad added the topic: not user facing topic category label Sep 27, 2025
@SherlockNoMad SherlockNoMad changed the title [rfc] Supporting exporting a model with DTensor params/inputs [DTensor][Export] Supporting exporting a model with DTensor params/inputs Sep 27, 2025

def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
    # install_free_tensors is required for dynamo to work
    with torch._dynamo.config.patch(install_free_tensors=True):
Reviewer (Contributor):

I think you also need to turn on inline_inbuilt_nn_modules.

SherlockNoMad (Author):

it's passing strict export without this patch

Reviewer (Contributor):

We want to turn this flag on very soon (#163921), and it could fix various silent bugs, so I think we should always have it on if downstream users end up using this API, because it might take some reverts/time to properly land the mentioned PR.

SherlockNoMad (Author):

ok, added.

suo and others added 11 commits September 29, 2025 15:19
As title, one issue was that our fake mode detection didn't understand
dtensor.

RFC because:
- I'm a dtensor noob so I don't know if this is the right way to use dtensor
- I don't like making torch/_guards.py aware of DTensor, looking for
  suggestions on alternative ways to structure the code.
pytree.register_constant(DTensorSpec)

# TODO: Having DTensorSpec in pytree causes issue with tensor_parallel_transformation
# Need to understand the interaction here
SherlockNoMad (Author):

@tugsbayasgalan

Something weird is going on here.

Also, I don't need torch.utils._pytree.register_constant(DTensorSpec) to make _dynamo_graph_capture_for_export pass.

It's just needed for strict export.

SherlockNoMad (Author):

alright, I reverted pytree.register_constant(DTensorSpec).

Reviewer (Contributor):

Ohh interesting ok

@SherlockNoMad (Author):

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 30, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

SherlockNoMad added a commit to pytorch/torchtitan that referenced this pull request Oct 28, 2025

This is an e2e prototype to run llama3-simplefsdp using export-y
aot_autograd workflow.

Setup: shard_dp = 2, tp = 4.

MVP
- [Done] Start with a simpleFSDP model, enable TP + FSDP
- [Done] Apply
[aot_export_joint_with_descriptors](pytorch/pytorch#163609)
on the parallelized module with DTensor inputs to get the joint graph
- [Done] Apply min_cut_partitioner to get forward and backward graph
modules
- [Done but needs verification] Apply prefetch/bucketing graph passes on
fw_gm and bw_gm to reorder/group the communication collectives
- [Done] Run the joint graph with `aot_compile_joint_with_descriptors`
- [Done] Regional Inductor for FlexAttention; needs to run on top of
pytorch/pytorch#165202 and
pytorch/pytorch#164776

Next Steps
- Enable CudaGraph
- Enable SimpleFSDP + EP
- Showcase user annotation on MoE for dispatch, compute, combine regions
- Enable PP with custom Runner


Issues
- pytorch/pytorch#164559
- pytorch/pytorch#164543
- What's the input order for the aot_export_joint graph? Using
model.parameters()'s order as inputs seems wrong.


Repro steps:
NGPU=8
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"
with-proxy ./run_train.sh --model.name compiler_toolkit.llama3
--compile.enable --parallelism.data_parallel_shard_degree=2
--parallelism.tensor_parallel_degree=4

Run with FlexAttention:
NGPU=8
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"
with-proxy ./run_train.sh --model.name compiler_toolkit.llama3
--compile.enable --parallelism.data_parallel_shard_degree=2
--parallelism.tensor_parallel_degree=4
--model.flavor=debugmodel_flex_attn

Sample output:
P1975157784: rank0_autograd_function_0fea2786.py
P1975158481: rank1_autograd_function_28587623.py

---------

Co-authored-by: Simon Fan <xmfan@meta.com>
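The MVP steps in the commit message above form a four-stage pipeline. A toy sketch follows; every function here is a stand-in for the real API named in the list (aot_export_joint_with_descriptors, min_cut_partitioner, the bucketing passes, aot_compile_joint_with_descriptors), and only the ordering of the steps is taken from the commit message.

```python
# Toy sketch of the capture -> partition -> comm-pass -> compile pipeline.
# All functions are hypothetical stand-ins, not real torch APIs.

def export_joint(module, inputs):
    return ["fw_ops", "bw_ops"]                 # stand-in joint graph

def min_cut_partition(joint):
    return [joint[0]], [joint[1]]               # fw_gm, bw_gm

def bucket_collectives(gm):
    return gm + ["bucketed_collective"]         # reorder/group comms

def compile_joint(fw_gm, bw_gm):
    return lambda: (fw_gm, bw_gm)               # stand-in compiled callable

def pipeline(module, inputs):
    joint = export_joint(module, inputs)        # 1. capture joint graph
    fw_gm, bw_gm = min_cut_partition(joint)     # 2. split forward/backward
    fw_gm = bucket_collectives(fw_gm)           # 3. communication passes
    bw_gm = bucket_collectives(bw_gm)
    return compile_joint(fw_gm, bw_gm)          # 4. compile

compiled = pipeline(None, None)
print(compiled())
```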
@github-actions github-actions bot deleted the bahuang/export_dtensor branch October 31, 2025 02:17
jquesnelle pushed a commit to NousResearch/torchtitan that referenced this pull request Nov 10, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
pytorchmergebot pushed a commit that referenced this pull request Mar 4, 2026
register_constant(DTensorSpec) in the export test helper was
permanently modifying global pytree state, causing subsequent
compiled DTensor tests to fail. With DTensorSpec registered as a
pytree constant, dynamo no longer decomposes glu into simpler ops
that have sharding strategies, so aten.glu.default gets passed
through to DTensor dispatch which can't handle it.

Wrap in try/finally to deregister after use.

Introduced in PR #163609

Authored with Claude.
Pull Request resolved: #176128
Approved by: https://github.com/SherlockNoMad
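The fix described in that commit, scoping a global registration with try/finally, is a general pattern. A minimal sketch follows; `REGISTRY`, `register_constant`, and `deregister_constant` are toy stand-ins for pytree's global state, not the real torch.utils._pytree API.

```python
# Minimal sketch of scoping a global registration with try/finally, so a
# test helper cannot leak state into later tests. All names are stand-ins.

REGISTRY = set()

def register_constant(cls):
    REGISTRY.add(cls)

def deregister_constant(cls):
    REGISTRY.discard(cls)

class DTensorSpec:  # stand-in for the real DTensorSpec
    pass

def run_export_test():
    register_constant(DTensorSpec)
    try:
        # ... export logic that needs DTensorSpec registered as a constant ...
        return DTensorSpec in REGISTRY
    finally:
        # Undo the registration so later tests see pristine global state.
        deregister_constant(DTensorSpec)

was_registered_during_test = run_export_test()
print(was_registered_during_test, DTensorSpec in REGISTRY)  # True False
```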