[DTensor][Export] Support exporting a model with DTensor params/inputs #163609
SherlockNoMad wants to merge 15 commits into main from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163609
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 949ea83 with merge base 8701f18. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
The interactions with subsystems I'm not familiar with are a question mark to me, but the parts I do understand look fine.
```python
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
    # install_free_tensors is required for dynamo to work
    with torch._dynamo.config.patch(install_free_tensors=True):
```
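The `install_free_tensors` flag is flipped via a scoped config patch so it never leaks outside the export call. As a rough illustration of that pattern, here is a minimal pure-Python sketch (the `Config` class below is a hypothetical stand-in, not the actual `torch._dynamo.config` implementation):

```python
from contextlib import contextmanager

# Hypothetical stand-in for a dynamo-style config object; the real
# torch._dynamo.config.patch() provides a similar scoped override.
class Config:
    install_free_tensors = False

    @classmethod
    @contextmanager
    def patch(cls, **overrides):
        # Save current values, apply the overrides, and restore on exit
        # (even if the body raises), so the flag never leaks globally.
        saved = {k: getattr(cls, k) for k in overrides}
        for k, v in overrides.items():
            setattr(cls, k, v)
        try:
            yield
        finally:
            for k, v in saved.items():
                setattr(cls, k, v)

with Config.patch(install_free_tensors=True):
    inside = Config.install_free_tensors   # flag is on inside the scope
after = Config.install_free_tensors        # restored to the default after
```

The try/finally is what makes this safe to use inside test helpers: the global default is restored even when the traced function raises.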
I think you also need to turn on inline_inbuilt_nn_modules.
it's passing strict export without this patch
We want to turn this flag on very soon (#163921), and it could fix various silent bugs. So I think we should always have it on if downstream users end up using this API, because it might take some reverts/time to properly land the mentioned PR.
As title. One issue was that our fake mode detection didn't understand DTensor. RFC because:
- I'm a DTensor noob, so I don't know if this is the right way to use DTensor.
- I don't like making torch/_guards.py aware of DTensor; looking for suggestions on alternative ways to structure the code.
```python
pytree.register_constant(DTensorSpec)

# TODO: Having DTensorSpec in pytree causes issue with tensor_parallel_transformation
# Need to understand the interaction here
```
Something weird is going on here.
Also, I don't need torch.utils._pytree.register_constant(DTensorSpec) to make _dynamo_graph_capture_for_export pass.
It's just needed for strict export.
Alright, I reverted pytree.register_constant(DTensorSpec).
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This is an e2e prototype to run llama3-simplefsdp using an export-y aot_autograd workflow. Setup: shard_dp = 2, tp = 4.

MVP
- [Done] Start with a SimpleFSDP model, enable TP + FSDP
- [Done] Apply [aot_export_joint_with_descriptors](pytorch/pytorch#163609) on the parallelized module with DTensor inputs to get the joint graph
- [Done] Apply min_cut_partitioner to get the forward and backward graph modules
- [Done, but needs verification] Apply prefetch/bucketing graph passes on fw_gm and bw_gm to reorder/group the communication collectives
- [Done] Run the joint graph with `aot_compile_joint_with_descriptors`
- [Done] Regional Inductor for FlexAttention; needs to run on top of pytorch/pytorch#165202 and pytorch/pytorch#164776

Next Steps
- Enable CUDAGraph
- Enable SimpleFSDP + EP
- Showcase user annotation on MoE for dispatch, compute, and combine regions
- Enable PP with a custom Runner

Issues
- pytorch/pytorch#164559
- pytorch/pytorch#164543
- What's the input order for the aot_export_joint graph? Using model.parameters()'s order as the input order seems wrong.

Repro steps:

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4

Run with FlexAttention:

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn

Sample output: P1975157784: rank0_autograd_function_0fea2786.py P1975158481: rank1_autograd_function_28587623.py

---------

Co-authored-by: Simon Fan <xmfan@meta.com>
register_constant(DTensorSpec) in the export test helper was permanently modifying global pytree state, causing subsequent compiled DTensor tests to fail. With DTensorSpec registered as a pytree constant, dynamo no longer decomposes glu into simpler ops that have sharding strategies, so aten.glu.default gets passed through to DTensor dispatch, which can't handle it. The fix wraps the registration in try/finally to deregister after use. Introduced in PR #163609. Authored with Claude. Pull Request resolved: #176128. Approved by: https://github.com/SherlockNoMad
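The shape of that fix — scoped mutation of a global registry with guaranteed cleanup — can be sketched in plain Python. The names below (`REGISTRY`, `register_constant`, `deregister_constant`, the placeholder `DTensorSpec`) are illustrative stand-ins for torch.utils._pytree's actual registration API, not the real implementation:

```python
# Global registry standing in for pytree's constant-type registrations.
REGISTRY = {}

def register_constant(cls):
    REGISTRY[cls] = True

def deregister_constant(cls):
    REGISTRY.pop(cls, None)

class DTensorSpec:  # placeholder for the real DTensorSpec class
    pass

def run_export_test():
    # Register only for the duration of this test helper.
    register_constant(DTensorSpec)
    try:
        # ... export logic that relies on DTensorSpec being a pytree constant ...
        return DTensorSpec in REGISTRY
    finally:
        # Restore global state so later tests see the default decomposition
        # behavior, even if the body above raised.
        deregister_constant(DTensorSpec)

during = run_export_test()          # registered while the test body runs
after = DTensorSpec in REGISTRY     # cleaned up afterwards
```

Without the finally clause, an exception in the test body would leave the registration behind, which is exactly the cross-test pollution the fix addresses.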
I experimented with 3 paths to get the joint graph for a DTensorized module and inputs,
and added tests to guard them.
Path 1 doesn't work, as the backward graph region is missing from the joint graph.
I am leaning towards making path 2 the recommended path.
If path 2 doesn't work going forward, we can fall back to path 3.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci