Skip to content

Add option to export FX Node metadata to StableHLO#7046

Merged
lsy323 merged 5 commits intomasterfrom
lsiyuan/export-node-metadata
May 14, 2024
Merged

Add option to export FX Node metadata to StableHLO#7046
lsy323 merged 5 commits intomasterfrom
lsiyuan/export-node-metadata

Conversation

@lsy323
Copy link
Copy Markdown
Collaborator

@lsy323 lsy323 commented May 10, 2024

This fixes #7014

Example usage

ep = torch.export.export(M(), args)
export_options = StableHLOExportOptions()
export_options.export_node_metadata = True
shlo = exported_program_to_stablehlo(ep, options=export_options)

The fx.node.meta(Including stack_trace, nn_module_stack and source_fn_stack) will be serialized to JSON, stored in the NameLoc in StableHLO

Example output

module @IrToHlo.24 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10xf32> [unknown], %arg1: tensor<10x16xf32> [unknown], %arg2: tensor<16xf32> [unknown], %arg3: tensor<16x4xf32> [unknown], %arg4: tensor<2x4xf32> [unknown]) -> tensor<2x10xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<2x10xf32> [unknown]
    %0 = stablehlo.transpose %arg3, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4,16]{0,1}"} : (tensor<16x4xf32>) -> tensor<4x16xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 146, in forward\\n    x = self.fc1(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc1,fc1,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc1,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_1;builtin_function_or_method.linear\22}aten__permute"
    %1 = stablehlo.dot_general %arg4, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x4xf32>, tensor<4x16xf32>) -> tensor<2x16xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 146, in forward\\n    x = self.fc1(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc1,fc1,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc1,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_1;builtin_function_or_method.linear\22}aten__addmm"
    %2 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<16xf32>) -> tensor<2x16xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 146, in forward\\n    x = self.fc1(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc1,fc1,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc1,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_1;builtin_function_or_method.linear\22}aten__addmm"
    %3 = stablehlo.add %1, %2 : tensor<2x16xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 146, in forward\\n    x = self.fc1(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc1,fc1,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc1,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_1;builtin_function_or_method.linear\22}aten__addmm"
    %4 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[16,10]{0,1}"} : (tensor<10x16xf32>) -> tensor<16x10xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 147, in forward\\n    x = self.fc2(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc2,fc2,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc2,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_2;builtin_function_or_method.linear\22}aten__permute"
    %5 = stablehlo.dot_general %3, %4, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x16xf32>, tensor<16x10xf32>) -> tensor<2x10xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 147, in forward\\n    x = self.fc2(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc2,fc2,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc2,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_2;builtin_function_or_method.linear\22}aten__addmm"
    %6 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<10xf32>) -> tensor<2x10xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 147, in forward\\n    x = self.fc2(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc2,fc2,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc2,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_2;builtin_function_or_method.linear\22}aten__addmm"
    %7 = stablehlo.add %5, %6 : tensor<2x10xf32> "{\22stack_trace\22: \22  File \\\22/home/lsiyuan/work/pytorch/xla/test/stablehlo/test_exports.py\\\22, line 147, in forward\\n    x = self.fc2(x)\\n\22, \22nn_module_stack\22: \22L__self__,,__main__.ExportTest.test_export_node_metadata.<locals>.M;L__self___fc2,fc2,torch.nn.modules.linear.Linear\22, \22source_fn_stack\22: \22l__self___fc2,torch.nn.modules.linear.Linear\22, \22torch_fn\22: \22linear_2;builtin_function_or_method.linear\22}aten__addmm"
    %8 = stablehlo.maximum %7, %cst : tensor<2x10xf32> "{}aten__relu"
    return %8 : tensor<2x10xf32> [unknown]
  } [unknown]
} [unknown]

@lsy323 lsy323 changed the title Add option to export node metadata to StableHLO Add option to export FX node metadata to StableHLO May 10, 2024
@lsy323 lsy323 changed the title Add option to export FX node metadata to StableHLO Add option to export FX Node metadata to StableHLO May 10, 2024
Comment thread torch_xla/stablehlo.py
Comment thread torch_xla/stablehlo.py Outdated
Comment thread torch_xla/stablehlo.py
@lsy323 lsy323 requested a review from chunnienc May 10, 2024 22:52
@lsy323 lsy323 added the stablehlo StableHLO related work label May 10, 2024
@lsy323 lsy323 force-pushed the lsiyuan/export-node-metadata branch 2 times, most recently from b067526 to cbcab73 Compare May 13, 2024 19:54
Siyuan Liu added 3 commits May 13, 2024 23:03
@lsy323 lsy323 force-pushed the lsiyuan/export-node-metadata branch from cbcab73 to 7a0cea0 Compare May 13, 2024 23:03
@lsy323 lsy323 merged commit 6f392cc into master May 14, 2024
@lsy323 lsy323 deleted the lsiyuan/export-node-metadata branch May 14, 2024 16:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

stablehlo StableHLO related work

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Export debug information to StableHLO

3 participants