Support non-traceable Custom Ops #6979

@thong3le

🚀 Feature

torch.export supports exporting black-box custom ops; however, we fail to export them to StableHLO using the exported_program_to_stablehlo API.
https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#custom-ops

Motivation

If a custom op contains non-traceable Python code, we can't export it to a StableHLO program. This means we can't cover as much of the model when exporting through StableHLO.

Pitch

Here is example PyTorch code:

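import torch
from torch.library import Library, impl, impl_abstract

# Define a new operator namespace and the custom op's schema.
m = Library("my_custom_library", "DEF")
m.define("custom_op(Tensor input) -> Tensor")

# Real (eager) implementation. The raise simulates non-traceable Python code:
# any tool that actually executes this kernel will fail here.
@impl(m, "custom_op", "CompositeExplicitAutograd")
def custom_op(x):
    raise Exception("DON'T GO HERE")
    return torch.relu(x)

# Abstract (meta) implementation: only describes the output's shape and dtype,
# which is what torch.export uses to trace through the op.
@impl_abstract("my_custom_library::custom_op")
def custom_op_meta(x):
    return torch.empty_like(x)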

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

em = torch.export.export(CustomOpExample(), (torch.randn(3, 3),))
em.graph_module.graph.print_tabular()

from torch_xla.stablehlo import exported_program_to_stablehlo
stablehlo_program = exported_program_to_stablehlo(em)
print(stablehlo_program.get_stablehlo_text())

As you can see, torch.export runs fine and produces this FX graph without executing the body of the custom_op implementation:

opcode         name       target                               args          kwargs
-------------  ---------  -----------------------------------  ------------  --------
placeholder    arg0_1     arg0_1                               ()            {}
call_function  sin        aten.sin.default                     (arg0_1,)     {}
call_function  custom_op  my_custom_library.custom_op.default  (sin,)        {}
call_function  cos        aten.cos.default                     (custom_op,)  {}
output         output     output                               ((cos,),)     {}

exported_program_to_stablehlo, however, fails because it executes custom_op and hits the exception.
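To make the contrast concrete, here is a minimal pure-Python sketch, assuming a hypothetical toy graph representation (the `GRAPH` tuples, `REAL_KERNELS`, `FAKE_KERNELS`, and `run` are illustrative names, not torch or torch_xla APIs). Shape-only propagation with abstract kernels never touches the real custom_op body, while evaluating the real kernels reaches it and raises, which roughly mirrors the behavior described above.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical toy graph mirroring the FX graph above: (result, op, input).
GRAPH: List[Tuple[str, str, str]] = [
    ("sin", "aten.sin", "arg0_1"),
    ("custom_op", "my_custom_library.custom_op", "sin"),
    ("cos", "aten.cos", "custom_op"),
]

def custom_op_real(xs: List[float]) -> List[float]:
    # Stands in for the non-traceable body of the real kernel.
    raise Exception("DON'T GO HERE")

# Real kernels operate on values (plain lists of floats in this toy).
REAL_KERNELS: Dict[str, Callable] = {
    "aten.sin": lambda xs: [math.sin(v) for v in xs],
    "aten.cos": lambda xs: [math.cos(v) for v in xs],
    "my_custom_library.custom_op": custom_op_real,
}

# Fake (abstract) kernels only map input shapes to output shapes,
# analogous to the impl_abstract registration returning empty_like(x).
FAKE_KERNELS: Dict[str, Callable] = {
    "aten.sin": lambda shape: shape,
    "aten.cos": lambda shape: shape,
    "my_custom_library.custom_op": lambda shape: shape,
}

def run(graph, kernels, graph_input):
    """Evaluate each node with the given kernel table; return the last node's result."""
    env = {"arg0_1": graph_input}
    for result, op, arg in graph:
        env[result] = kernels[op](env[arg])
    return env[graph[-1][0]]

# Shape-only propagation never calls the real custom_op body.
print(run(GRAPH, FAKE_KERNELS, (3, 3)))  # (3, 3)

# Executing the real kernels reaches custom_op_real and raises.
try:
    run(GRAPH, REAL_KERNELS, [0.1, 0.2, 0.3])
except Exception as err:
    print("lowering failed:", err)  # lowering failed: DON'T GO HERE
```

The point of the sketch is only that shape information alone is enough to walk the graph; whether a given lowering needs more than that is exactly what this issue asks about.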

When I comment out the line raise Exception("DON'T GO HERE"), exported_program_to_stablehlo works; however, it traces into custom_op and lowers relu to stablehlo.maximum:

module @IrToHlo.8 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<3x3xf32>
    %1 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %2 = stablehlo.maximum %1, %0 : tensor<3x3xf32>
    %3 = stablehlo.cosine %2 : tensor<3x3xf32>
    return %3 : tensor<3x3xf32>
  }
}
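In the output above, the inlined relu shows up as an elementwise maximum against a zero constant (%0 and %2). A minimal pure-Python check of that identity, using plain lists of floats in place of tensors (an illustrative assumption, not torch_xla code):

```python
import math
from typing import List

def relu(xs: List[float]) -> List[float]:
    # Reference definition: negative inputs become 0, others pass through.
    return [v if v > 0 else 0.0 for v in xs]

def maximum(xs: List[float], ys: List[float]) -> List[float]:
    # Elementwise maximum, like stablehlo.maximum on same-shaped operands.
    return [max(a, b) for a, b in zip(xs, ys)]

def lowered_relu(xs: List[float]) -> List[float]:
    # The zeros list plays the role of stablehlo.constant dense<0.000000e+00>.
    zeros = [0.0] * len(xs)
    return maximum(xs, zeros)

# Feed sin(x) values, as in the module above (sin, then the inlined relu).
sample = [math.sin(v) for v in (-2.0, -0.5, 0.0, 0.5, 2.0)]
assert relu(sample) == lowered_relu(sample)
```

This is why the exported module contains no trace of my_custom_library.custom_op: its Python body was executed and only the resulting aten ops were lowered.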

Could we support exporting black-box custom ops all the way to StableHLO without executing them? We would like to see something like this in the output:

module @IrToHlo.8 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<3x3xf32>
    %1 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %2 = "stablehlo.custom_call"(%1) {call_target_name = "my_custom_library.custom_op"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %3 = stablehlo.cosine %2 : tensor<3x3xf32>
    return %3 : tensor<3x3xf32>
  }
}
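As a rough illustration of the requested behavior, here is a toy string-emitting lowering, assuming a hypothetical node list that mirrors the FX graph printed earlier. The names `lower`, `KNOWN`, and the `blackbox` parameter are made up for this sketch and are not part of torch_xla; ops in `blackbox` are emitted as an opaque stablehlo.custom_call instead of being executed or traced.

```python
from typing import List, Set, Tuple

# Hypothetical node list mirroring the FX graph above: (result, target, argument).
Node = Tuple[str, str, str]

GRAPH: List[Node] = [
    ("sin", "aten.sin.default", "arg0_1"),
    ("custom_op", "my_custom_library.custom_op.default", "sin"),
    ("cos", "aten.cos.default", "custom_op"),
]

# Ops this toy lowering translates directly into StableHLO ops.
KNOWN = {"aten.sin.default": "stablehlo.sine", "aten.cos.default": "stablehlo.cosine"}

TY = "tensor<3x3xf32>"

def lower(graph: List[Node], blackbox: Set[str]) -> str:
    """Emit StableHLO-like text; black-box targets become custom_call ops."""
    names = {"arg0_1": "%arg0"}
    lines = [f"func.func @main(%arg0: {TY}) -> {TY} {{"]
    for i, (result, target, arg) in enumerate(graph):
        ssa = f"%{i}"
        operand = names[arg]
        if target in KNOWN:
            lines.append(f"  {ssa} = {KNOWN[target]} {operand} : {TY}")
        elif target in blackbox:
            # Emit an opaque call instead of running or tracing the Python impl.
            call_name = target.rsplit(".", 1)[0]  # drop the ".default" overload
            lines.append(
                f'  {ssa} = "stablehlo.custom_call"({operand}) '
                f'{{call_target_name = "{call_name}"}} : ({TY}) -> {TY}'
            )
        else:
            raise NotImplementedError(target)
        names[result] = ssa
    lines.append(f"  return {names[graph[-1][0]]} : {TY}")
    lines.append("}")
    return "\n".join(lines)

print(lower(GRAPH, {"my_custom_library.custom_op.default"}))
```

With this design the exporter never needs the op's Python body, only its schema and output shape; whatever consumes the resulting StableHLO would then have to supply an implementation for the "my_custom_library.custom_op" call target.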

Labels: stablehlo (StableHLO related work)