
Support export custom op to stablehlo custom call #7017

Merged
lsy323 merged 4 commits into master from lsiyuan/shlo-cc
May 2, 2024

Conversation

@lsy323 lsy323 (Collaborator) commented May 2, 2024

This PR resolves #6979

Support exporting custom ops to StableHLO custom calls. There are two user journeys:

  1. Register a custom op to be wrapped as a StableHLO custom call during export, via StableHLOExportOptions:
import torch
from torch.library import Library
from torch_xla.stablehlo import (StableHLOExportOptions,
                                 exported_program_to_stablehlo)

m = Library("my_custom_library", "DEF")
m.define("custom_op(Tensor input) -> Tensor")

class M(torch.nn.Module):

  def forward(self, x):
    x = torch.sin(x)
    x = torch.ops.my_custom_library.custom_op(x)
    x = torch.cos(x)
    return x

options = StableHLOExportOptions()
options.custom_ops_allowed_in_graph.add("my_custom_library")
ep = torch.export.export(M(), (torch.randn(3, 3),))
shlo_module = exported_program_to_stablehlo(ep, options)

Example output:

module @IrToHlo.6 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
    %0 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %1 = stablehlo.custom_call @my_custom_library.custom_op.default(%0) {api_version = 0 : i32} : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %2 = stablehlo.cosine %1 : tensor<3x3xf32>
    return %2 : tensor<3x3xf32>
  }
}
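
Note the custom call target @my_custom_library.custom_op.default in the dump: it appears to be built from the op's library namespace, op name, and overload, with "default" used for the default overload. A small sketch of that naming convention (my reading of the output above, not a documented torch_xla API):

```python
# Hypothetical helper mirroring the target naming seen in the dump above:
# "<library namespace>.<op name>.<overload>", with "default" for the
# default overload. An observation from the output, not a torch_xla API.
def custom_call_target(namespace, op_name, overload="default"):
    return ".".join((namespace, op_name, overload))

print(custom_call_target("my_custom_library", "custom_op"))
# my_custom_library.custom_op.default
```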
  2. Call the stablehlo_custom_call API in the XLA dispatch implementation of the custom op, for finer-grained control over the generated custom call:
from torch.library import impl
from torch_xla.experimental.stablehlo_custom_call import stablehlo_custom_call

# Reuses the Library `m` ("my_custom_library") defined in the first example.
m.define("custom_op3(Tensor input) -> Tensor")

@impl(m, "custom_op3", "XLA")
def custom_op3_xla(x):
  res = stablehlo_custom_call((x,), "custom_op3", [x.shape[1:]],
                              [torch.int8], True, "backend_config", 1)
  return res

class M(torch.nn.Module):

  def forward(self, x):
    x = torch.sin(x)
    x = torch.ops.my_custom_library.custom_op3(x)
    x = x + 1
    return x

ep = torch.export.export(M(), (torch.randn(3, 3),))
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()

Example output:

module @IrToHlo.10 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3xi8> {
    %c = stablehlo.constant dense<1> : tensor<3xi8>
    %0 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %1 = stablehlo.custom_call @custom_op3(%0) {backend_config = "backend_config", has_side_effect = true} : (tensor<3x3xf32>) -> tensor<3xi8>
    %2 = stablehlo.add %1, %c : tensor<3xi8>
    return %2 : tensor<3xi8>
  }
}
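
To connect the arguments in the snippet to this dump: the third and fourth arguments are the output shapes and dtypes, so [x.shape[1:]] with a 3x3 input yields the tensor<3xi8> result type, while True, "backend_config", and 1 surface as the has_side_effect, backend_config, and api_version attributes. A minimal, framework-free sketch of that bookkeeping (function name and string layout are illustrative, not the torch_xla implementation):

```python
# Illustrative model (not the torch_xla implementation) of how the
# stablehlo_custom_call arguments in the snippet above map to the
# custom_call line printed in the dump.
def render_custom_call(target, output_shapes, output_dtypes,
                       has_side_effect, backend_config):
    # Each output gets a result type like tensor<3xi8> built from shape/dtype.
    results = [
        "tensor<%sx%s>" % ("x".join(str(d) for d in shape), dtype)
        for shape, dtype in zip(output_shapes, output_dtypes)
    ]
    attrs = 'backend_config = "%s", has_side_effect = %s' % (
        backend_config, str(has_side_effect).lower())
    return "stablehlo.custom_call @%s {%s} : -> %s" % (target, attrs, results[0])

input_shape = (3, 3)
# [x.shape[1:]] from the snippet: drop the leading dim of the 3x3 input.
out_shapes = [input_shape[1:]]
print(render_custom_call("custom_op3", out_shapes, ["i8"],
                         True, "backend_config"))
# stablehlo.custom_call @custom_op3 {backend_config = "backend_config", has_side_effect = true} : -> tensor<3xi8>
```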

Test: Added unit tests covering multiple inputs/outputs and both user journeys.

Future work:

  1. Register stablehlo_custom_call as a torch op; currently the op schema registration fails for parameters of type ScalarType[].
  2. Merge the tpu_custom_call and custom_call code paths. There is currently duplicated code between the two; tpu_custom_call should be a special case of custom_call. This requires extending custom_call to support xla::CustomCallWithLayout. cc @alanwaketan

cc @GleasonK

@lsy323 lsy323 requested a review from qihqi May 2, 2024 05:00
@lsy323 lsy323 requested a review from qihqi May 2, 2024 20:50
@lsy323 lsy323 merged commit 666eccb into master May 2, 2024
@lsy323 lsy323 deleted the lsiyuan/shlo-cc branch May 2, 2024 22:37

xinli-sw commented Dec 4, 2024

Hi @lsy323, thanks for the amazing work!

We are investigating issue #8385 and found that output_operand_aliasing here is silently dropped when calling the custom XLA OP.

Do you know if there is a way to restore the aliasing information here? If not, do you recall what limitation caused the aliasing information to be dropped?



Successfully merging this pull request may close these issues.

Support non-traceable Custom Ops
