I found that providing a fallthrough impl on the PythonDispatcher key for a PyOperator causes the handler to redispatch infinitely to the PythonDispatcher key itself. It seems there is an issue with falling through the PythonDispatcher key here.
maximum recursion depth exceeded
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
<ipython-input-11-0bd412362562> in <module>
21 cond.fallthrough(DispatchKey_PythonDispatcher())
22
---> 23 graph = make_fx(f, tracing_mode="symbolic")(torch.ones(3, 2), torch.tensor(False))
/torch/fx/experimental/proxy_tensor.py in wrapped(*args)
650 with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
651 sym_mode, proxy_mode, disable_autocast_cache(): # type: ignore[attr-defined]
--> 652 t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
653
654 # TODO: kind of a bad way to do it, should maybe figure out a better way
/torch/fx/experimental/proxy_tensor.py in dispatch_trace(root, tracer, concrete_args)
410 concrete_args: Optional[Tuple[Any, ...]] = None,
411 ) -> GraphModule:
--> 412 graph = tracer.trace(root, concrete_args)
413 name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
414 return GraphModule(tracer.root, graph, name)
/torch/fx/_symbolic_trace.py in trace(self, root, concrete_args)
737 "output",
738 "output",
--> 739 (self.create_arg(fn(*args)),),
740 {},
741 type_expr=fn.__annotations__.get("return", None),
/torch/fx/experimental/proxy_tensor.py in wrapped(*proxies)
424 track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
425
--> 426 out = f(*tensors)
427 out = pytree.tree_map_only(
428 torch.Tensor,
<ipython-input-11-0bd412362562> in f(x, y)
17
18 def f(x, y):
---> 19 return cond(y, true_fn, false_fn, [x])
20
21 cond.fallthrough(DispatchKey_PythonDispatcher())
/torch/_ops.py in __call__(self, *args, **kwargs)
171
172 dispatch_key_set = _compute_keyset(args, kwargs)
--> 173 return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
174
175 def name(self):
/torch/_ops.py in dispatch(self, dispatch_key, *args, **kwargs)
161
162 assert dispatch_key in self.table
--> 163 return self.table[dispatch_key](*args, **kwargs)
164
165 def __call__(self, *args, **kwargs):
/torch/_ops.py in inner(*args, **kwargs)
184 args, kwargs
185 )
--> 186 return self.dispatch(
187 all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
188 )
... last 2 frames repeated, from the frame below ...
/torch/_ops.py in dispatch(self, dispatch_key, *args, **kwargs)
161
162 assert dispatch_key in self.table
--> 163 return self.table[dispatch_key](*args, **kwargs)
164
165 def __call__(self, *args, **kwargs):
RecursionError: maximum recursion depth exceeded
🐛 Describe the bug
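I found that providing a fallthrough impl on the PythonDispatcher key for a PyOperator causes the handler to redispatch infinitely to the PythonDispatcher key itself. It seems there is an issue with falling through the PythonDispatcher key here. In the traceback below, the fallthrough handler in `torch/_ops.py` redispatches on `all_keys_after_current_masked.highestPriorityTypeId()`. That key apparently still resolves to PythonDispatcher, so the same two frames repeat until a `RecursionError` is raised.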
(discovered in PR: #88767)
Error:
cc @ezyang @voznesenskym
Versions
PyTorch main branch