🚀 Feature
`torch.export` supports exporting blackbox custom ops; however, we fail to export them to StableHLO using the `exported_program_to_stablehlo` API.
https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#custom-ops
Motivation
If a custom op contains non-traceable Python code, we can't export the model to a StableHLO program. This means we can't cover as much of the model when exporting through StableHLO.
Pitch
Here is the example PyTorch code:
```python
import torch
from torch.library import Library, impl, impl_abstract

m = Library("my_custom_library", "DEF")
m.define("custom_op(Tensor input) -> Tensor")

@impl(m, "custom_op", "CompositeExplicitAutograd")
def custom_op(x):
    raise Exception("DON'T GO HERE")
    return torch.relu(x)

@impl_abstract("my_custom_library::custom_op")
def custom_op_meta(x):
    return torch.empty_like(x)

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

em = torch.export.export(CustomOpExample(), (torch.randn(3, 3),))
em.graph_module.graph.print_tabular()

from torch_xla.stablehlo import exported_program_to_stablehlo

stablehlo_program = exported_program_to_stablehlo(em)
print(stablehlo_program.get_stablehlo_text())
```
As you can see, `torch.export` runs fine and gives us the following fx graph, without caring what is inside the `custom_op` implementation:
```
opcode         name       target                               args          kwargs
-------------  ---------  -----------------------------------  ------------  --------
placeholder    arg0_1     arg0_1                               ()            {}
call_function  sin        aten.sin.default                     (arg0_1,)     {}
call_function  custom_op  my_custom_library.custom_op.default  (sin,)        {}
call_function  cos        aten.cos.default                     (custom_op,)  {}
output         output     output                               ((cos,),)     {}
```
`exported_program_to_stablehlo` fails because it runs `custom_op` and hits the `Exception`. When I comment out the line `raise Exception("DON'T GO HERE")`, `exported_program_to_stablehlo` works, but it traces into `custom_op` and converts the `relu` to `stablehlo.maximum`:
```mlir
module @IrToHlo.8 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<3x3xf32>
    %1 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %2 = stablehlo.maximum %1, %0 : tensor<3x3xf32>
    %3 = stablehlo.cosine %2 : tensor<3x3xf32>
    return %3 : tensor<3x3xf32>
  }
}
```
I wonder if we can support exporting blackbox custom ops all the way to StableHLO without executing the op. We would want to see something like this in the output:
```mlir
module @IrToHlo.8 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
    %0 = stablehlo.sine %arg0 : tensor<3x3xf32>
    %1 = stablehlo.custom_call @my_custom_library.custom_op(%0) : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %2 = stablehlo.cosine %1 : tensor<3x3xf32>
    return %2 : tensor<3x3xf32>
  }
}
```