Skip to content

Commit d7a9986

Browse files
angelayipytorchmergebot
authored andcommitted
Remove hacky python dispatcher fallthrough
1 parent c2d7508 commit d7a9986

2 files changed

Lines changed: 2 additions & 12 deletions

File tree

functorch/experimental/_cond.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,6 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
148148
return true_outs
149149

150150

151-
# We cannot directly call fallthrough here due to issue #89037.
152-
@cond.py_impl(DispatchKey.PythonDispatcher)
153-
def cond_python_dispatcher(*args):
154-
_ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher))
155-
return cond(*args)
156-
157151

158152
def _has_potential_branch_input_mutation(branch, inputs):
159153
"""
@@ -242,6 +236,7 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
242236
return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
243237

244238
# TODO(voz): Make this automatic for keys, this is very ugly atm
239+
cond.fallthrough(DispatchKey.PythonDispatcher)
245240
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
246241
cond.fallthrough(DispatchKey.ADInplaceOrView)
247242
cond.fallthrough(DispatchKey.BackendSelect)

functorch/experimental/_map.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,6 @@ def map_fake_tensor_mode(f, xs, *args):
9191
outs = [f(x, *args) for x in xs]
9292
return outs[0].new_empty([xs.shape[0], *outs[0].shape])
9393

94-
# We cannot directly call fallthrough here due to issue #89037.
95-
@map.py_impl(DispatchKey.PythonDispatcher)
96-
def map_python_dispatcher(*args):
97-
_ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher))
98-
return map(*args)
99-
10094
@map.py_impl(torch._C._functorch.TransformType.Functionalize)
10195
def map_functionalize(interpreter, f, xs, *args):
10296
"""
@@ -128,6 +122,7 @@ def map_functionalize(interpreter, f, xs, *args):
128122
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
129123

130124
# TODO(voz) Make this automatic for keys, this is very ugly atm
125+
map.fallthrough(DispatchKey.PythonDispatcher)
131126
map.fallthrough(DispatchKey.PythonTLSSnapshot)
132127
map.fallthrough(DispatchKey.ADInplaceOrView)
133128
map.fallthrough(DispatchKey.BackendSelect)

0 commit comments

Comments
 (0)