Skip to content

PyOperator.fallthrough(DispatchKey.PythonDispatcher) will cause infinite recursion during redispatch. #89037

@zhxchen17

Description

@zhxchen17

🐛 Describe the bug

I found that providing a fallthrough impl on the PythonDispatcher key for a PyOperator causes the handler to infinitely redispatch to PythonDispatcher itself. There seems to be an issue with how fallthrough handles the PythonDispatcher key here.

(discovered in PR: #88767)

from torch._C import DispatchKey
import torch
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx

# A hack to get DispatchKey.PythonDispatcher without updating pybind and rebuilding pytorch.
# Should really be just: DispatchKey.PythonDispatcher
# Placed here for quick repro.
def DispatchKey_PythonDispatcher():
    """Hack: recover DispatchKey.PythonDispatcher via private dispatcher APIs.

    Avoids updating pybind / rebuilding PyTorch just to expose the enum value.
    Relies on PythonDispatcher being the highest-priority key in the full
    keyset after CPU — TODO confirm this ordering holds on the build in use.
    """
    keys_after_cpu = torch._C._dispatch_keyset_full_after(DispatchKey.CPU)
    return keys_after_cpu.highestPriorityTypeId()

def true_fn(x):
    """Branch taken when the predicate is true: elementwise sine of x."""
    return torch.sin(x)


def false_fn(x):
    """Branch taken when the predicate is false: elementwise cosine of x."""
    return torch.cos(x)


def f(x, y):
    """Select between sin and cos of x via functorch's cond on predicate y."""
    branches = (true_fn, false_fn)
    return cond(y, *branches, [x])

# Without this fallthrough, we simply miss an impl for the PythonDispatcher
# key, which errors during make_fx(tracing_mode="symbolic"). With it, the
# fallthrough handler redispatches back to PythonDispatcher itself — the bug.
cond.fallthrough(DispatchKey_PythonDispatcher())

# Symbolic tracing triggers the PythonDispatcher path; this call recurses
# until RecursionError (see traceback below).
graph = make_fx(f, tracing_mode="symbolic")(torch.ones(3, 2), torch.tensor(False))

Error:

maximum recursion depth exceeded
---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-11-0bd412362562> in <module>
     21 cond.fallthrough(DispatchKey_PythonDispatcher())
     22 
---> 23 graph = make_fx(f, tracing_mode="symbolic")(torch.ones(3, 2), torch.tensor(False))
/torch/fx/experimental/proxy_tensor.py in wrapped(*args)
    650         with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
    651              sym_mode, proxy_mode, disable_autocast_cache():  # type: ignore[attr-defined]
--> 652             t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
    653 
    654         # TODO: kind of a bad way to do it, should maybe figure out a better way
/torch/fx/experimental/proxy_tensor.py in dispatch_trace(root, tracer, concrete_args)
    410         concrete_args: Optional[Tuple[Any, ...]] = None,
    411 ) -> GraphModule:
--> 412     graph = tracer.trace(root, concrete_args)
    413     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    414     return GraphModule(tracer.root, graph, name)
/torch/fx/_symbolic_trace.py in trace(self, root, concrete_args)
    737                     "output",
    738                     "output",
--> 739                     (self.create_arg(fn(*args)),),
    740                     {},
    741                     type_expr=fn.__annotations__.get("return", None),
/torch/fx/experimental/proxy_tensor.py in wrapped(*proxies)
    424         track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
    425 
--> 426         out = f(*tensors)
    427         out = pytree.tree_map_only(
    428             torch.Tensor,
<ipython-input-11-0bd412362562> in f(x, y)
     17 
     18 def f(x, y):
---> 19     return cond(y, true_fn, false_fn, [x])
     20 
     21 cond.fallthrough(DispatchKey_PythonDispatcher())
/torch/_ops.py in __call__(self, *args, **kwargs)
    171 
    172         dispatch_key_set = _compute_keyset(args, kwargs)
--> 173         return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
    174 
    175     def name(self):
/torch/_ops.py in dispatch(self, dispatch_key, *args, **kwargs)
    161 
    162         assert dispatch_key in self.table
--> 163         return self.table[dispatch_key](*args, **kwargs)
    164 
    165     def __call__(self, *args, **kwargs):
/torch/_ops.py in inner(*args, **kwargs)
    184                 args, kwargs
    185             )
--> 186             return self.dispatch(
    187                 all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
    188             )
... last 2 frames repeated, from the frame below ...
/torch/_ops.py in dispatch(self, dispatch_key, *args, **kwargs)
    161 
    162         assert dispatch_key in self.table
--> 163         return self.table[dispatch_key](*args, **kwargs)
    164 
    165     def __call__(self, *args, **kwargs):
RecursionError: maximum recursion depth exceeded

cc @ezyang @voznesenskym

Versions

pytorch main branch

Metadata

Metadata

Assignees

No one assigned

    Labels

    Labels: module: python dispatcher; triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions