Skip to content

Commit 71d5202

Browse files
committed
[POC] Don't decompose matmul
Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
1 parent 9ca41a9 commit 71d5202

3 files changed

Lines changed: 33 additions & 0 deletions

File tree

a.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch
2+
from torch.fx.experimental.proxy_tensor import make_fx
3+
from torch._dispatch.python import enable_python_dispatcher, patch_py_impls
4+
5+
with enable_python_dispatcher(), patch_py_impls({
6+
torch.ops.aten.matmul.default: {torch._C.DispatchKey.AutogradCPU: torch._C.DispatchKey.Autograd}
7+
}):
8+
print(make_fx(torch.matmul)(torch.randn(2, 3), torch.randn(3, 4)))

torch/_dispatch/python.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
import torch.utils._pytree as pytree
66
import itertools
7+
from typing import Callable, Dict
78

89
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
910

@@ -140,3 +141,26 @@ def enable_crossref_functionalize():
140141
finally:
141142
for op in all_known_overloads():
142143
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
144+
145+
@contextmanager
146+
def patch_py_impls(all_patches: Dict[torch._ops.OpOverload, Dict[torch._C.DispatchKey, Callable]]):
147+
"""
148+
Temporarily patch the dispatcher registrations in the Python Dispatcher,
149+
undoing them when you exit the context manager. This is useful for
150+
temporarily adding pre-autograd decompositions, among other things.
151+
"""
152+
saved_tables = {}
153+
for op, patches in all_patches.items():
154+
# TODO: Make this public API on OpOverload instead
155+
# of groveling the attribute directly
156+
saved_tables[op] = op.py_kernels.copy()
157+
for k, fn in patches.items():
158+
op.py_impl(k)(fn)
159+
try:
160+
yield
161+
finally:
162+
for op in all_patches:
163+
# TODO: Make this OpOverload API
164+
op.py_kernels.clear()
165+
op.py_kernels.update(saved_tables[op])
166+
op._dispatch_cache.clear()

torch/_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def namespace(self):
295295
return self._schema.name.split("::")[0]
296296

297297
def decompose(self, *args, **kwargs):
298+
return NotImplemented
298299
dk = torch._C.DispatchKey.CompositeImplicitAutograd
299300
if dk in self.py_kernels:
300301
# NB: This branch is not too necessary anymore, because we can

0 commit comments

Comments
 (0)