🚀 Feature
We propose that PyTorch/XLA leverage CapabilityBasedPartitioner from torch.fx.passes.infra.partitioner to support CPU fallback for the dynamo-xla bridge. The CapabilityBasedPartitioner is a tool in torch.fx that partitions an FX graph into "backend" parts, given a list of supported operators. The partitioner tries to form the largest subgraphs that contain only the supported ops.
Motivation
In the current dynamo-xla integration, the dynamo bridge passes down an FX module and expects the backend to return a single function/hash representing the whole model. However, when the FX module contains ops that PyTorch/XLA does not support, PyTorch/XLA falls back at the unsupported ops and generates more than one graph. In such cases, the current torch_xla_bridge crashes.
This issue was originally tracked in an upstream feature request.
Pitch
In order to support CPU fallback for the dynamo-xla bridge, we propose that PyTorch/XLA leverage CapabilityBasedPartitioner from torch.fx.passes.infra.partitioner. The CapabilityBasedPartitioner is a tool in torch.fx that partitions an FX graph into "backend" parts, given a list of supported operators. The partitioner tries to form the largest subgraphs that contain only the supported ops.
Currently, other backends such as nvFuser and onnxruntime leverage CapabilityBasedPartitioner for a similar use case.
CapabilityBasedPartitioner
Example usage of CapabilityBasedPartitioner is shown below. First, we initialize CapabilityBasedPartitioner with the FX graph module and a list of supported ops. Then we call the partitioner's partition_and_fuse() method to produce the fused_graph_module.
partitioner = CapabilityBasedPartitioner(graph_module, self.supported_ops)
fused_graph_module = partitioner.partition_and_fuse()
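For illustration, here is a self-contained sketch (not part of the proposal itself) that runs the two calls above on a toy function traced with make_fx, so the graph contains aten OpOverload targets. DemoSupport and f are hypothetical stand-ins for this demo; a fuller operator-support class is sketched in a later section.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

class DemoSupport(OperatorSupport):
    def __init__(self):
        # The base class matches fully qualified names, overload included
        super().__init__({
            "torch.ops.aten.add.Tensor": None,
            "torch.ops.aten.sub.Tensor": None,
        })

def f(x, y):
    # div is "unsupported" here, so it stays outside the partition
    return torch.div(torch.add(x, y) - y, y)

graph_module = make_fx(f)(torch.randn(2), torch.randn(2))
partitioner = CapabilityBasedPartitioner(graph_module, DemoSupport())
fused_graph_module = partitioner.partition_and_fuse()
print(fused_graph_module.graph)  # add/sub are grouped into a fused_* submodule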
Next, we need to override each fused submodule's __call__() function with our custom backend function. This looks like:
for node in fused_graph_module.graph.nodes:
    # Identify the fused submodules produced by the partitioner
    if "fused_" in node.name:
        fused_module = getattr(fused_graph_module, node.name)
        fused_module._wrapped_call = my_backend_function
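As a hedged sketch of the shape my_backend_function could take: in current torch.fx, GraphModule.__call__ dispatches through self._wrapped_call(self, *args, **kwargs), so the override receives the fused submodule as its first argument. xla_compile_and_run below is a hypothetical stand-in for the existing dynamo-xla bridge entry point.

def my_backend_function(fused_module, *args, **kwargs):
    # The fused submodule contains only supported ops, so the bridge can
    # safely assume a single XLA graph/hash for it (hypothetical helper)
    return xla_compile_and_run(fused_module, *args, **kwargs)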
The supported_ops is provided by subclassing torch.fx.passes.operator_support.OperatorSupport to supply a string dictionary of supported ops, and by overriding the OperatorSupport.is_node_supported function. Example usage is shown below:
import typing as t

from torch._ops import OpOverload
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.nn import Module

class MyOperatorSupport(OperatorSupport):
    def __init__(self):
        support_dict = {
            # List of supported ops
            "torch.ops.aten.add": None,
            "torch.ops.aten.sub": None,
        }
        super().__init__(support_dict)

    def is_node_supported(
        self, submodules: t.Mapping[str, Module], node: Node
    ) -> bool:
        # The FX subgraph should be purely functional
        if node.op not in CALLABLE_NODE_OPS:
            return False
        # Ops in support_dict don't carry an overload name,
        # so use the overload packet's qualified name for OpOverload targets
        if isinstance(node.target, OpOverload):
            target = _get_qualified_name(node.target.overloadpacket)
            if target in self._support_dict:
                return True
        return super().is_node_supported(submodules, node)
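Putting the pieces together, a possible end-to-end backend could look like the sketch below. xla_backend_with_cpu_fallback is a hypothetical name, and my_backend_function is the hypothetical wrapper sketched earlier; only the fused submodules would reach XLA, while the remaining nodes keep running through the default eager (CPU) path.

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

def xla_backend_with_cpu_fallback(graph_module: torch.fx.GraphModule, example_inputs):
    partitioner = CapabilityBasedPartitioner(graph_module, MyOperatorSupport())
    fused_graph_module = partitioner.partition_and_fuse()
    for node in fused_graph_module.graph.nodes:
        if "fused_" in node.name:
            # Only these submodules are handed to XLA; everything else
            # stays on the eager (CPU) path
            fused_module = getattr(fused_graph_module, node.name)
            fused_module._wrapped_call = my_backend_function
    return fused_graph_module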
Caveats
Note that PyTorch/XLA also falls back based on input types/values for certain ops. For example, PyTorch/XLA currently only supports the addmm op with alpha and beta values of 1 and falls back for other values. The CapabilityBasedPartitioner currently accepts only a list of supported op names and does not take additional details into account (input shapes, values, etc.). We'll need to come up with a separate way to conditionally fall back based on input values.
[Update as of March 9th, 2023] One possible solution could be to dry-run our C++ op in the overridden is_node_supported function. This way, we can quickly check whether the op needs to fall back based on input values/types.
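As a sketch of that idea (assuming torch_xla's metrics counters, where a CPU fallback is recorded as an aten::<op> counter), op_falls_back below is a hypothetical helper that executes an op once on the XLA device and diffs the fallback counters:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def op_falls_back(op, args, kwargs=None) -> bool:
    # Snapshot the aten:: fallback counters, run the op, and compare
    kwargs = kwargs or {}
    before = {n: met.counter_value(n) for n in met.counter_names()
              if n.startswith("aten::")}
    op(*args, **kwargs)
    after = {n: met.counter_value(n) for n in met.counter_names()
             if n.startswith("aten::")}
    return after != before

# E.g. the addmm caveat above: beta != 1 should trip the CPU fallback
t = torch.randn(2, 2, device=xm.xla_device())
print(op_falls_back(torch.ops.aten.addmm, (t, t, t), {"beta": 2.0}))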
Additional context
CapabilityBasedPartitioner: https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/infra/partitioner.py
nvFuser uses CapabilityBasedPartitioner: https://github.com/YLGH/pytorch/blob/c9a0204ef4fb8edd29aeaebc5c74867fa1093a48/torch/fx/passes/backends/nvfuser.py
onnxruntime uses CapabilityBasedPartitioner: https://github.com/ramkrishna2910/onnxruntime/blob/master/orttraining/orttraining/python/training/torchdynamo/ort_backend.py
cc @JackCaoG @alanwaketan