
[RFC] Supporting CPU fallback for the dynamo-xla bridge #4742

@wonjoo-wj


🚀 Feature

We propose that PyTorch/XLA leverage CapabilityBasedPartitioner from torch.fx.passes.infra.partitioner to support CPU fallback in the dynamo-xla bridge. The CapabilityBasedPartitioner is a tool in torch.fx that partitions an FX graph into "backend" parts, given a list of supported operators. The partitioner tries to form the largest subgraphs that contain only supported ops.

Motivation

In the current dynamo-xla integration, the dynamo bridge passes down an FX module and expects the backend to return a single function/hash representing the whole model. However, when that FX module contains ops that PyTorch/XLA does not support, PyTorch/XLA will fall back at the unsupported ops and generate more than one graph. In such cases, the current torch_xla_bridge will crash.

This issue was originally tracked in an upstream feature-request issue.

Pitch

In order to support CPU fallback for the dynamo-xla bridge, we propose that PyTorch/XLA leverage CapabilityBasedPartitioner from torch.fx.passes.infra.partitioner. The CapabilityBasedPartitioner is a tool in torch.fx that partitions an FX graph into "backend" parts, given a list of supported operators. The partitioner tries to form the largest subgraphs that contain only supported ops.

Currently, other backends such as nvFuser and onnxruntime leverage CapabilityBasedPartitioner for a similar use case.
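To illustrate the partitioning idea independently of torch.fx, here is a minimal toy sketch that splits a flat sequence of op names into maximal contiguous runs of supported ops. This is a simplified stand-in for the real CapabilityBasedPartitioner, which operates on a full FX graph (with data dependencies) rather than a linear list; all names here are illustrative.

```python
# Toy illustration of capability-based partitioning: group a flat op
# sequence into maximal contiguous runs of supported vs. unsupported
# ops. The real partitioner works on an FX graph, not a linear trace.

def partition(ops, supported):
    """Group consecutive ops into (is_supported, [ops]) partitions."""
    partitions = []
    for op in ops:
        ok = op in supported
        if partitions and partitions[-1][0] == ok:
            # Same supportedness as the previous op: grow the partition
            partitions[-1][1].append(op)
        else:
            # Supportedness changed: start a new partition
            partitions.append((ok, [op]))
    return partitions

supported_ops = {"aten.add", "aten.sub", "aten.mul"}
trace = ["aten.add", "aten.mul", "aten.custom_op", "aten.sub", "aten.add"]
print(partition(trace, supported_ops))
```

The supported runs would be handed to XLA, while the unsupported partition in the middle stays on the eager CPU path.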

CapabilityBasedPartitioner

Example usage of CapabilityBasedPartitioner is shown below. First, we initialize CapabilityBasedPartitioner with the FX graph module and a list of supported ops. Then we call the partitioner's partition_and_fuse() method to produce the fused_graph_module.

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

partitioner = CapabilityBasedPartitioner(graph_module, self.supported_ops)
fused_graph_module = partitioner.partition_and_fuse()

Now, we need to override fused_module's __call__() function with our custom backend function. This looks like:

for node in fused_graph_module.graph.nodes:
    # Identify fused submodule
    if "fused_" in node.name:
        fused_module = getattr(fused_graph_module, node.name)
        fused_module._wrapped_call = my_backend_function
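The call-redirection pattern above can be sketched with plain Python objects: the fused submodule routes `__call__` through an internal wrapper, so replacing the wrapper swaps the execution backend without touching the surrounding graph. The class and backend function below are hypothetical stand-ins, not torch.fx internals.

```python
# Toy sketch of redirecting a fused submodule's call path to a custom
# backend function by swapping its internal call wrapper.

class FusedSubmodule:
    def __init__(self, name):
        self.name = name
        self._wrapped_call = self._default_call

    def _default_call(self, x):
        # Placeholder for default (interpreter) execution
        return x

    def __call__(self, x):
        # All calls dispatch through the wrapper attribute
        return self._wrapped_call(x)

def my_backend_function(x):
    # Hypothetical compiled/XLA execution path
    return x * 2

fused = FusedSubmodule("fused_0")
fused._wrapped_call = my_backend_function
print(fused(21))
```

After the swap, calling the submodule dispatches to the backend function while the rest of the graph still sees an ordinary callable.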

The supported_ops argument is provided by subclassing the torch.fx.passes.operator_support.OperatorSupport class with a string dictionary of supported ops and overriding its is_node_supported function. Example usage is shown below:

import typing as t

from torch._ops import OpOverload
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.nn import Module


class MyOperatorSupport(OperatorSupport):
    def __init__(self):
        support_dict = {
            # List of supported ops
            "torch.ops.aten.add": None,
            "torch.ops.aten.sub": None,
        }

        super().__init__(support_dict)

    def is_node_supported(
        self, submodules: t.Mapping[str, Module], node: Node
    ) -> bool:

        # The FX subgraph should be purely functional
        if node.op not in CALLABLE_NODE_OPS:
            return False

        # Ops in support_dict don't carry an overload name, so use the
        # overload packet's qualified name for OpOverload targets
        if isinstance(node.target, OpOverload):
            target = _get_qualified_name(node.target.overloadpacket)
            if target in self._support_dict:
                return True

        return super().is_node_supported(submodules, node)
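The overload-name normalization in the example above can be shown standalone: support_dict keys name overload packets (e.g. "torch.ops.aten.add"), while graph nodes carry specific overloads (e.g. "aten.add.Tensor"), so lookup must strip the overload suffix first. The helper below is an illustrative reimplementation, not the actual torch.fx utility.

```python
# Standalone sketch of overload-packet name normalization: map a
# specific overload name like "aten.add.Tensor" onto the packet-level
# key "torch.ops.aten.add" before checking the support dictionary.

SUPPORT_DICT = {
    "torch.ops.aten.add": None,
    "torch.ops.aten.sub": None,
}

def packet_qualified_name(overload_name):
    # "aten.add.Tensor" -> "torch.ops.aten.add"
    namespace, op, *_overload = overload_name.split(".")
    return f"torch.ops.{namespace}.{op}"

def is_supported(overload_name):
    return packet_qualified_name(overload_name) in SUPPORT_DICT

print(is_supported("aten.add.Tensor"))
print(is_supported("aten.addmm.default"))
```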

Caveats

Note that PyTorch/XLA also falls back based on input types/values for certain ops. For example, PyTorch/XLA currently supports the addmm op only with alpha and beta values of 1, and will fall back for other values. The CapabilityBasedPartitioner currently accepts a list of supported op names only and does not take additional fallback criteria (input shapes, values, etc.) into account. We'll need to come up with a separate way to conditionally fall back based on input values.

[Update as of March 9th, 2023] One possible solution could be to dry-run our C++ op in the overridden is_node_supported function. This way, we can quickly check if the op needs fall back based on input values/types.
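One way to frame such value-conditional checks is a per-op predicate registry consulted from is_node_supported, sketched below in plain Python. The registry, predicate, and addmm rule (supported only when alpha == beta == 1, per the Caveats section) are hypothetical illustrations, not existing PyTorch/XLA code.

```python
# Sketch of value-conditional support checks: ops with known input
# constraints get a predicate over their kwargs; ops without a
# registered predicate pass the value check by default.

VALUE_CHECKS = {
    # addmm is only supported when alpha and beta are both 1
    "torch.ops.aten.addmm": lambda kwargs: (
        kwargs.get("alpha", 1) == 1 and kwargs.get("beta", 1) == 1
    ),
}

def passes_value_check(op, kwargs):
    check = VALUE_CHECKS.get(op)
    if check is not None:
        return check(kwargs)
    return True  # no value constraint registered for this op

print(passes_value_check("torch.ops.aten.addmm", {"alpha": 1, "beta": 1}))
print(passes_value_check("torch.ops.aten.addmm", {"alpha": 2}))
```

The dry-run approach described above would replace the hand-written predicates with an actual attempt to lower the op in C++, but the call site in is_node_supported would look similar.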

Additional context

cc @JackCaoG @alanwaketan
