
Commit f7a0000
Update on "Rewrite fallthrough to more closely match how C++ works"

Fallthrough is modeled as a mask that we use to remove keys from the computed dispatch key set when determining eligibility. It's possible this addresses #89037 in a better way than #95891, but I cannot easily tell, as the original repro no longer works and the new PR does not have a test.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
2 parents 1249f43 + 6c52c18 commit f7a0000
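The mask-based eligibility check described in the commit message can be sketched in plain Python. This is a hypothetical, heavily simplified model (a tiny `Key` enum standing in for `DispatchKey`, Python sets standing in for the C++ bitset): registering a fallthrough contributes a mask of keys that are removed from the computed dispatch key set before the highest-priority key is selected.

```python
from enum import IntEnum

# Hypothetical, simplified dispatch keys; the real DispatchKey enum has
# many more entries, and DispatchKeySet is a C++ bitset over them.
class Key(IntEnum):
    CPU = 0
    AutogradCPU = 1
    Python = 2  # higher value = higher dispatch priority in this sketch

def compute_dispatch_key(candidates: set, fallthrough_mask: set) -> "Key":
    """Pick the highest-priority key after masking out fallthroughs.

    Mirrors the idea in the commit: fallthrough kernels are modeled as a
    mask of keys removed from the computed dispatch key set, so a key with
    a registered fallthrough is simply never eligible for selection.
    """
    eligible = candidates - fallthrough_mask
    if not eligible:
        raise RuntimeError("no eligible dispatch key")
    return max(eligible)

candidates = {Key.CPU, Key.AutogradCPU, Key.Python}

# Registering a fallthrough for Python masks it out of eligibility,
# so dispatch falls through to the next-highest-priority key.
masked = compute_dispatch_key(candidates, {Key.Python})
assert masked == Key.AutogradCPU

# With no fallthroughs, the highest-priority key wins.
assert compute_dispatch_key(candidates, set()) == Key.Python
```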

2 files changed: 8 additions & 3 deletions

torch/_C/__init__.pyi.in

Lines changed: 3 additions & 0 deletions

@@ -1325,12 +1325,15 @@ class DispatchKeySet:
     def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
     def highestPriorityTypeId(self) -> DispatchKey: ...
     def has(self, k: _dispatchkey) -> _bool: ...
+    def add(self, k: _dispatchkey) -> DispatchKeySet: ...
+    def remove(self, k: _dispatchkey) -> DispatchKeySet: ...
     def __repr__(self) -> str: ...

 _dispatch_autogradother_backends: DispatchKeySet

 def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ...
 def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ...
+def _dispatch_keyset_full() -> DispatchKeySet: ...
 def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ...
 def _dispatch_get_backend_keyset_from_autograd(
     dispatch: _dispatchkey,
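The new `add` and `remove` stubs return a `DispatchKeySet` rather than mutating in place, i.e. they behave functionally. A toy pure-Python stand-in (not the real `torch._C.DispatchKeySet`, which is a C++ bitset; names here mirror the stubs above for illustration only) can show that semantics:

```python
class KeySet:
    """Toy stand-in for DispatchKeySet: add/remove return *new* sets,
    leaving the original untouched, matching the functional signatures
    in the .pyi stub (``-> DispatchKeySet``)."""

    def __init__(self, keys=()):
        self._keys = frozenset(keys)

    def add(self, k):
        return KeySet(self._keys | {k})

    def remove(self, k):
        return KeySet(self._keys - {k})

    def has(self, k):
        return k in self._keys

    def __and__(self, other):
        return KeySet(self._keys & other._keys)

    def __repr__(self):
        return f"KeySet({sorted(self._keys)})"

base = KeySet(["CPU"])
grown = base.add("AutogradCPU")
assert grown.has("AutogradCPU")
assert not base.has("AutogradCPU")  # original set is unchanged
assert not grown.remove("CPU").has("CPU")
```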

torch/_ops.py

Lines changed: 5 additions & 3 deletions

@@ -3,7 +3,7 @@
 import inspect
 import sys
 import types
-from typing import Any, Dict, Type
+from typing import Any, Callable, Dict, Type, Union

 import torch._C

@@ -82,7 +82,9 @@ def __init__(self):
         # thought of as an open world extension of dispatch keys, so it
         # makes sense that you should be able to register them, the same
         # way you can register dispatch keys.
-        self.python_key_mode_table: Dict[Type[TorchDispatchMode]] = {}
+        self.python_key_mode_table: Dict[
+            Type[TorchDispatchMode], Callable[..., Any]
+        ] = {}

         # This table allows you to override the behavior of functorch
         # transformations. NB: this currently only does something for

@@ -115,7 +117,7 @@ def inner(fn):
             if k in self.py_kernels:
                 raise RuntimeError(
-                    f"Trying to override a python impl for {dispatch_key_or_mode} on operator {self._name}"
+                    f"Trying to override a python impl for {k} on operator {self.name()}"
                 )
             self.py_kernels[k] = fn
             self._dispatch_cache.clear()
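The error-message change in the last hunk reports the resolved key `k` (and the op's public `name()`) instead of the raw `dispatch_key_or_mode` argument. A minimal sketch of that registration path, with `OpOverload` and `resolve` as simplified hypothetical stand-ins for the real PyTorch machinery:

```python
def resolve(dispatch_key_or_mode):
    # Hypothetical resolver: a mode class maps to its class name,
    # a plain dispatch key string passes through unchanged.
    return getattr(dispatch_key_or_mode, "__name__", dispatch_key_or_mode)

class OpOverload:
    """Simplified stand-in for the operator wrapper in torch/_ops.py."""

    def __init__(self, name):
        self._name = name
        self.py_kernels = {}

    def name(self):
        return self._name

    def py_impl(self, dispatch_key_or_mode):
        # Resolve to a concrete key first, so the duplicate-registration
        # error reports what is actually stored in py_kernels.
        k = resolve(dispatch_key_or_mode)

        def inner(fn):
            if k in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} "
                    f"on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            return fn

        return inner

op = OpOverload("aten::add")
op.py_impl("CPU")(lambda *args: None)  # first registration succeeds

try:
    op.py_impl("CPU")(lambda *args: None)  # duplicate: raises
except RuntimeError as e:
    assert "CPU" in str(e) and "aten::add" in str(e)
```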
