Do not treat ReinterpretView as a realized node#159920

Closed
shengfukevin wants to merge 1 commit intopytorch:mainfrom
shengfukevin:export-D79692316

@shengfukevin shengfukevin commented Aug 6, 2025

Summary:
Do not treat ReinterpretView as a realized node

Function [gather_origins](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/utils.py#L888) calls is_realized_node to decide whether an FX node should be included in the origins of an IR node. ReinterpretView is considered a realized node, so it is left out of the origins, which leads to an incomplete graph. For example:

import torch
import torch._dynamo as torchdynamo

device = "cuda"  # the graphs below were captured on cuda:0

@torchdynamo.optimize("inductor")
def fn(input_data, weight):
    normalized_input = input_data * weight.unsqueeze(0)
    return normalized_input

input_data = torch.randn(4272, 192, requires_grad=True).to(device)
weight = torch.randn(192, requires_grad=True).to(device)
fn(input_data, weight)

The original FX graph returned by [get_kernel_metadata](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/utils.py#L723) is the following:
%primals_2 : Tensor "f32[4272, 192][192, 1]cuda:0" = PlaceHolder[target=primals_2]
%primals_1 : Tensor "f32[192][1]cuda:0" = PlaceHolder[target=primals_1]
%mul : Tensor "f32[4272, 192][192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_2, %unsqueeze), kwargs = {})
return %mul
The unsqueeze op is missing, even though %mul still takes %unsqueeze as an argument.

With this DIFF, the new FX graph is the following:
%primals_2 : Tensor "f32[4272, 192][192, 1]cuda:0" = PlaceHolder[target=primals_2]
%primals_1 : Tensor "f32[192][1]cuda:0" = PlaceHolder[target=primals_1]
%unsqueeze : Tensor "f32[1, 192][192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%primals_1, 0), kwargs = {})
%mul : Tensor "f32[4272, 192][192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_2, %unsqueeze), kwargs = {})
return %mul
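
The effect described above can be illustrated with a small toy model. This is only a sketch under stated assumptions, not the Inductor implementation: `ToyIRNode`, `gather_origins_sketch`, `mul_origins`, and the `realized_kinds` set are hypothetical names invented for the illustration, and the real `gather_origins` operates on Inductor IR objects rather than strings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyIRNode:
    """Hypothetical stand-in for an IR node; not an Inductor class."""
    kind: str                          # e.g. "InputBuffer", "ReinterpretView"
    origins: frozenset = frozenset()   # names of FX nodes behind this IR node


def gather_origins_sketch(args, realized_kinds):
    """Union the origins of every argument whose kind is not treated as realized."""
    collected = set()
    for arg in args:
        if arg.kind not in realized_kinds:
            collected |= arg.origins
    return collected


# Toy inputs of the mul kernel from the example above.
primals_2 = ToyIRNode("InputBuffer")
unsqueeze_view = ToyIRNode("ReinterpretView", frozenset({"unsqueeze"}))


def mul_origins(realized_kinds):
    # The kernel's own FX node plus whatever is gathered from its inputs.
    return {"mul"} | gather_origins_sketch([primals_2, unsqueeze_view], realized_kinds)


before = mul_origins({"InputBuffer", "ReinterpretView"})  # view treated as realized
after = mul_origins({"InputBuffer"})                      # view no longer realized

print(sorted(before))  # ['mul']
print(sorted(after))   # ['mul', 'unsqueeze']
```

In this toy model, dropping `"ReinterpretView"` from the set of realized kinds is what brings `unsqueeze` back into the origins of the `mul` kernel, mirroring the difference between the two FX graphs shown above.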

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

pytorch-bot bot commented Aug 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159920

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 503143e with merge base 231c722:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
This pull request was exported from Phabricator. Differential Revision: D79692316

@shengfukevin shengfukevin self-assigned this Aug 7, 2025
@shengfukevin shengfukevin requested a review from mlazos August 7, 2025 18:27
shengfukevin added a commit to shengfukevin/pytorch that referenced this pull request Aug 7, 2025
Test Plan:
buck2 run mode/opt caffe2/test:test_profiler_cuda  -- profiler.test_execution_trace.TestExecutionTraceCUDA.test_triton_fx_graph_with_et_cuda

Rollback Plan:

Differential Revision: D79692316
@shengfukevin

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 8, 2025
@pytorchmergebot

Merge failed

Reason: This PR needs a `release notes:` label
If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@shengfukevin

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Aug 8, 2025
@shengfukevin

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@shengfukevin

@pytorchbot merge

@pytorchmergebot

Merge started

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
Pull Request resolved: pytorch#159920
Approved by: https://github.com/mlazos
pytorch-bot bot pushed a commit that referenced this pull request Aug 12, 2025
Summary:
Fix unit test test_equivalent_template_code

#159920 treats ReinterpretView as a not-realized node when searching for the FX origin nodes of a fused Triton kernel. In test_equivalent_template_code, there is a transpose node (which is a ReinterpretView) before the matmul; before PR 159920 it was not part of the FX graph segment. FX origin nodes are used to build the name of the Triton kernel, and test_equivalent_template_code checks its result against a hard-coded Triton kernel name, which is why it failed with PR 159920. The fix is to update the Triton kernel name in the unit test.

Test Plan:
buck2 run mode/opt caffe2/test/inductor:benchmark_fusion -- caffe2.test.inductor.test_benchmark_fusion.BenchmarkMultiTemplateFusionCudaTest

Rollback Plan:

Differential Revision: D80101711
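
Why a hard-coded kernel name breaks once an extra origin node is picked up can be sketched as follows. The naming rule here is an assumption made purely for illustration: `kernel_name_sketch`, its prefix, and the op names `mm` and `transpose` are hypothetical and are not taken from Inductor or from the test.

```python
def kernel_name_sketch(origin_ops, index=0, prefix="triton_tem_fused"):
    """Hypothetical naming rule: prefix + sorted unique origin op names + index."""
    return "_".join([prefix, *sorted(set(origin_ops)), str(index)])


# Before the change: the transpose (a ReinterpretView) is not among the origins.
name_before = kernel_name_sketch(["mm"])
# After the change: the transpose contributes to the origins, hence to the name.
name_after = kernel_name_sketch(["mm", "transpose"])

print(name_before)  # triton_tem_fused_mm_0
print(name_after)   # triton_tem_fused_mm_transpose_0

# A test that hard-codes the old name no longer matches the generated one.
hard_coded_expectation = "triton_tem_fused_mm_0"
print(hard_coded_expectation == name_after)  # False
```

Under any naming scheme of this shape, a test asserting on the literal name has to be updated whenever the set of origin nodes changes.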
pytorch-bot bot pushed a commit that referenced this pull request Aug 13, 2025
Reviewed By: clee2000
shengfukevin added a commit to shengfukevin/pytorch that referenced this pull request Aug 13, 2025
shengfukevin added a commit to shengfukevin/pytorch that referenced this pull request Aug 13, 2025
pytorchmergebot pushed a commit that referenced this pull request Aug 13, 2025
Pull Request resolved: #160432
Approved by: https://github.com/clee2000
chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
4 participants