[compile] Regional inductor compilation with fx.annotate#164776

Closed
anijain2305 wants to merge 22 commits into gh/anijain2305/900/base from gh/anijain2305/900/head

Conversation

anijain2305 (Contributor) commented Oct 6, 2025

Stack from ghstack (oldest at bottom):

This PR introduces a way to compile a region of an FX graph using `fx.traceback.annotate`.

UX

  1. In the user code, mark the region that you want Inductor to compile using `with fx_traceback.annotate({"compile_with_inductor": 0})`. For now, we rely only on the key `compile_with_inductor` and ignore the integer value; the logic can be extended as new needs arise.

Example

        import torch
        import torch.fx.traceback as fx_traceback

        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)
  2. You have to explicitly instruct the compiler to act on the annotations by applying the `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial; a user might expect that setting the annotation alone is enough. But for now, to control the blast radius, we require the explicit opt-in. One such example (an end-to-end sketch follows this list):

# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )

  3. Fixable in the short term: you have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable; we just need to make CI happy.
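
Putting the pieces together, a minimal end-to-end sketch (hedged: the description does not spell out the import path for `compile_fx_annotated_nodes_with_inductor`, so it must be brought into scope from wherever this PR defines it):

```
import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.backends.common import aot_autograd

# Assumption: `compile_fx_annotated_nodes_with_inductor` has been imported
# from wherever this PR defines it; the exact path is not shown above.

def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )

def fn(x, y):
    sin = torch.sin(x)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)

opt_fn = torch.compile(fn, backend=aot_eager_regional_inductor(), fullgraph=True)

# Point 3 above: annotations only reach the compiler under preserve_node_meta.
with fx_traceback.preserve_node_meta():
    out = opt_fn(torch.randn(10, requires_grad=True), torch.randn(10))
    out.sum().backward()
```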

Implementation

  1. Rely on `CapabilityBasedPartitioner` to "scoop" out regions based on the annotations, and then create subgraphs (call_module nodes) in the main graph.
  2. Call `torch._inductor.standalone_compile` on these subgraphs, and splice the returned callable into the FX graph in place of each call_module node. A simplified sketch of this pass is shown below.
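
A hypothetical, heavily simplified sketch of such a pass; the operator-support class, the `custom` meta key, the `fused_` naming, and the fake-input recovery are illustrative assumptions, not the PR's actual code:

```
import torch
import torch._inductor
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport


class _AnnotatedNodes(OperatorSupport):
    # Treat a node as partitionable iff it carries the custom annotation.
    def is_node_supported(self, submodules, node):
        return "compile_with_inductor" in node.meta.get("custom", {})


def compile_fx_annotated_nodes_with_inductor(gm, example_inputs):
    # 1) Scoop annotated regions into call_module subgraphs.
    partitioner = CapabilityBasedPartitioner(
        gm, _AnnotatedNodes(), allows_single_node_partition=True
    )
    gm = partitioner.fuse_partitions(partitioner.propose_partitions())

    # 2) Compile each subgraph and splice the returned callable in.
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and node.target.startswith("fused_"):
            submod = getattr(gm, node.target)
            # Recover example inputs from FakeTensor metadata on the args.
            fake_inputs = [a.meta["val"] for a in node.args]
            compiled = torch._inductor.standalone_compile(submod, fake_inputs)
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(compiled, args=node.args)
                new_node.meta.update(node.meta)
            node.replace_all_uses_with(new_node)
            gm.graph.erase_node(node)
    gm.recompile()
    return gm
```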

The resulting graph looks something like this (search for `torch__inductor_standalone_compile_inner`):

Forward graph

class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(sin, primals_2)

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
        getitem: "f32[10]" = inner[0];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
        return (sin_1, primals_1, primals_2, sin, getitem)

Backward graph

class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1);  primals_1 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        cos: "f32[10]" = torch.ops.aten.cos.default(add);  add = None
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2);  mul_1 = sin = primals_2 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
        getitem: "f32[10]" = inner[0]
        getitem_1: "f32[10]" = inner[1];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1);  getitem_1 = cos_1 = None
        return (mul_4, getitem)

Some issues raised in the HOP meeting

  1. CSE will not differentiate between nodes with different custom metadata and may do the wrong thing.
  2. SAC: the recomputed forward will be smaller than the original forward. Will we then compile a smaller region?
  3. If there is an op in the middle of a region that does not disturb the topology, is it still one subgraph?
  4. What happens with the nesting of fx_traceback.annotate? Are there any ordering requirements?
  5. What are we going to use the annotations for?
    a) compile flex
    b) streams
    c) nn.Module info to organize MoE components for pipelining
    d) PP stages
    e) Rename graph nodes for more debugging
    f) No nested regional compile

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela

pytorch-bot bot commented Oct 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164776

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit a84e384 with merge base 83cbba8:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 requested a review from bdhirsh as a code owner October 6, 2025 22:29
ezyang (Contributor) commented Oct 8, 2025

need pr desc

anijain2305 (Contributor, Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025

Pull Request resolved: pytorch#164776
Approved by: https://github.com/SherlockNoMad
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…orch#164776)"

This reverts commit 1e4c7df.

Reverted pytorch#164776 on behalf of https://github.com/malfet due to Looks like this one broke everything, not the top of the stack ([comment](pytorch#164776 (comment)))
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025

Pull Request resolved: pytorch#164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: pytorch#165188
return gm


def _recursive_compile_fx_annotated_nodes_with_inductor(gm):

Review comment (Contributor):

should we maintain an invariant that gm only contains `call_function`, `placeholder`, and `output`? This would make all passes easier to write.
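
For illustration, a hypothetical helper enforcing that invariant might look like this (the name and error message are made up):

```
import torch

def assert_flat_graph(gm: torch.fx.GraphModule) -> None:
    # Invariant proposed above: only placeholder, call_function, and output.
    allowed = {"placeholder", "call_function", "output"}
    for node in gm.graph.nodes:
        if node.op not in allowed:
            raise AssertionError(f"unexpected node op {node.op!r} at {node.name}")
```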

chenmillie added a commit to chenmillie/pytorch that referenced this pull request Oct 27, 2025
Summary:
This PR introduces an inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. This has similar motivation as pytorch#164776 in providing flexibility to selectively disable Inductor lowering for specific nodes.

The implementation simply adds a check for the `"should_fallback"` metadata annotation on FX graph nodes. If this is set to `True`, the lowerer falls back before attempting the normal lowering path. Note that since these are user-directed fallbacks dependent upon specific, customized conditions, use `add_to_fallback_set=False` to avoid permanent overwrites of inductor's lowering/fallback rules. 


Simple example marking nodes for fallback based on custom predicates:

```
def should_fallback_predicate(node: torch.fx.Node, predicate: Callable[[torch.fx.Node], bool]):
    # Apply the predicate and mark the node for fallback if it matches
    if predicate(node):
        node.meta["should_fallback"] = True
```

Test Plan: added a CI test

Differential Revision: D85347587
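
For context, a rough sketch of where such a check could sit relative to Inductor's normal lowering path (a hypothetical wrapper, not the actual diff; `fallback_handler` and `lowerings` are Inductor's existing helpers):

```
from torch._inductor.lowering import fallback_handler, lowerings

def lower_with_user_fallback(node, args, kwargs):
    # Honor the user-directed annotation before the normal lowering path;
    # add_to_fallback_set=False keeps this per-node instead of permanently
    # registering the op as a fallback.
    if node.meta.get("should_fallback", False):
        return fallback_handler(node.target, add_to_fallback_set=False)(*args, **kwargs)
    return lowerings[node.target](*args, **kwargs)
```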
SherlockNoMad added a commit to pytorch/torchtitan that referenced this pull request Oct 28, 2025

This is an e2e prototype to run llama3-simplefsdp using an export-style
aot_autograd workflow.

Setup: shard_dp = 2, tp = 4.

MVP
- [Done] Start with a simpleFSDP model, enable TP + FSDP 
- [Done] Apply
[aot_export_joint_with_descriptors](pytorch/pytorch#163609)
on parallelized module with DTensor input to get the joint graph
- [Done] Apply min_cut_partitioner to get forward and backward graph
module
- [Done but Need verification] Apply prefetch/bucketing graph passes on
fw_gm and bw_gm to reorder/group the communication collectives
- [Done] Run the joint graph with `aot_compile_joint_with_descriptors`
- [Done] Regional Inductor for FlexAttention, needs to run on top of
pytorch/pytorch#165202 and
pytorch/pytorch#164776

Next Steps
- Enable CudaGraph
- Enable SimpleFSDP + EP 
- Showcase user annotation on MoE for dispatch, compute, combine region
- Enable PP with custom Runner


Issues
- pytorch/pytorch#164559
- pytorch/pytorch#164543
- What is the input order for the aot_export_joint graph? Using the order of
model.parameters() as the input order seems wrong.


Repro steps:
NGPU=8
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"
with-proxy ./run_train.sh --model.name compiler_toolkit.llama3
--compile.enable --parallelism.data_parallel_shard_degree=2
--parallelism.tensor_parallel_degree=4

Run with FlexAttention:
NGPU=8
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"
with-proxy ./run_train.sh --model.name compiler_toolkit.llama3
--compile.enable --parallelism.data_parallel_shard_degree=2
--parallelism.tensor_parallel_degree=4
--model.flavor=debugmodel_flex_attn

Sample output:
P1975157784: rank0_autograd_function_0fea2786.py
P1975158481: rank1_autograd_function_28587623.py

---------

Co-authored-by: Simon Fan <xmfan@meta.com>
pytorchmergebot pushed a commit that referenced this pull request Oct 29, 2025

Pull Request resolved: #166339
Approved by: https://github.com/blaine-rister, https://github.com/eellison
github-actions bot deleted the gh/anijain2305/900/head branch November 21, 2025 02:14
Labels

ci-no-td · ciflow/inductor · ciflow/trunk · fx · Merged · module: dynamo · module: inductor · release notes: fx · Reverted
