@@ -150,7 +150,7 @@ def inner(fn):
150150
151151 dispatch_key = dispatch_key_or_mode_or_transform
152152 assert (
153- dispatch_key != DispatchKey .Python
153+ dispatch_key != torch . _C . DispatchKey .Python
154154 ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
155155 assert isinstance (dispatch_key , torch ._C .DispatchKey )
156156 assert dispatch_key not in self .table
@@ -162,10 +162,10 @@ def inner(fn):
162162 def dispatch (self , dispatch_key , * args , ** kwargs ):
163163 from torch .utils ._python_dispatch import _get_current_dispatch_mode
164164
165- if dispatch_key == DispatchKey .FuncTorchDynamicLayerFrontMode :
165+ if dispatch_key == torch . _C . DispatchKey .FuncTorchDynamicLayerFrontMode :
166166 return dispatch_functorch (self , args , kwargs )
167167
168- if dispatch_key == DispatchKey .Python :
168+ if dispatch_key == torch . _C . DispatchKey .Python :
169169 # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
170170 curr_mode = _get_current_dispatch_mode ()
171171 assert (
@@ -205,8 +205,9 @@ def inner(*args, **kwargs):
205205 - self .fallthrough_keys
206206 - DispatchKeySet (dispatch_key )
207207 )
208- highest_key = all_keys_after_current_masked .highestPriorityTypeId ()
209- return self .dispatch (highest_key , * args , ** kwargs )
208+ return self .dispatch (
209+ all_keys_after_current_masked .highestPriorityTypeId (), * args , ** kwargs
210+ )
210211
211212 return inner
212213
@@ -225,9 +226,9 @@ def inner(*args, **kwargs):
225226 all_keys_after_current_masked = all_keys_after_current & _compute_keyset (
226227 args , kwargs
227228 )
228- highest_key = all_keys_after_current_masked . highestPriorityTypeId ()
229-
230- return self . dispatch ( highest_key , * args , ** kwargs )
229+ return self . dispatch (
230+ all_keys_after_current_masked . highestPriorityTypeId (), * args , ** kwargs
231+ )
231232
232233 return inner
233234
0 commit comments