# --- new file (example script for patch_py_impls; path not shown in this view) ---
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dispatch.python import enable_python_dispatcher, patch_py_impls

with enable_python_dispatcher(), patch_py_impls({
    torch.ops.aten.matmul.default: {torch._C.DispatchKey.AutogradCPU: torch._C.DispatchKey.Autograd}
}):
    print(make_fx(torch.matmul)(torch.randn(2, 3), torch.randn(3, 4)))
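This script appears to serve as a smoke test for the new context manager: under the Python dispatcher, the AutogradCPU entry for aten.matmul.default is temporarily remapped to the Autograd key (note the patch value here is itself a DispatchKey rather than a callable, despite the Callable annotation below), make_fx traces matmul under that remapped dispatch, and leaving the with block restores the original table.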
# --- torch/_dispatch/python.py ---
import torch
import torch.utils._pytree as pytree
import itertools
from typing import Callable, Dict

__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
# ... (unchanged tail of enable_crossref_functionalize)
    finally:
        for op in all_known_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)

@contextmanager
def patch_py_impls(all_patches: Dict[torch._ops.OpOverload, Dict[torch._C.DispatchKey, Callable]]):
    """
    Temporarily patch the dispatcher registrations in the Python Dispatcher,
    undoing them when you exit the context manager. This is useful for
    temporarily adding pre-autograd decompositions, among other things.
    """
    saved_tables = {}
    for op, patches in all_patches.items():
        # TODO: Make this public API on OpOverload instead
        # of groveling the attribute directly
        saved_tables[op] = op.py_kernels.copy()
        for k, fn in patches.items():
            op.py_impl(k)(fn)
    try:
        yield
    finally:
        for op in all_patches:
            # TODO: Make this OpOverload API
            op.py_kernels.clear()
            op.py_kernels.update(saved_tables[op])
            op._dispatch_cache.clear()
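The docstring's mention of pre-autograd decompositions suggests a usage pattern like the following sketch. Everything beyond patch_py_impls itself is an assumption for illustration: the choice of aten.addmm, the decomposition body, and the use of the AutogradCPU key (mirroring the example script above) are hypothetical, not part of this diff.

import torch
from torch._dispatch.python import enable_python_dispatcher, patch_py_impls

def addmm_decomp(input, mat1, mat2, *, beta=1, alpha=1):
    # Hypothetical pre-autograd decomposition: installed at the AutogradCPU
    # key, it intercepts addmm before autograd records it, and the mm/add
    # calls below redispatch from the top (so gradients flow through them).
    return beta * input + alpha * torch.mm(mat1, mat2)

with enable_python_dispatcher(), patch_py_impls({
    torch.ops.aten.addmm.default: {torch._C.DispatchKey.AutogradCPU: addmm_decomp}
}):
    out = torch.addmm(torch.randn(2, 4), torch.randn(2, 3), torch.randn(3, 4))
# On exit, the finally block restores the saved py_kernels tables and
# clears the dispatch cache, so addmm dispatches normally again.

Note the dict-of-dicts signature: a single call can patch several overloads at several dispatch keys at once, and all of them are undone together.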
# --- torch/_ops.py ---
# ... (context: class OpOverload, end of def namespace)
        return self._schema.name.split("::")[0]

    def decompose(self, *args, **kwargs):
        return NotImplemented
        dk = torch._C.DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not too necessary anymore, because we can
            # ...
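Note that the added return NotImplemented short-circuits decompose unconditionally: the CompositeImplicitAutograd lookup below it becomes unreachable dead code, so callers of this method now always see NotImplemented.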