Skip to content

Commit cea298f

Browse files
committed
cleanup
1 parent dbf18e3 commit cea298f

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

torch/_ops.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def inner(fn):
150150

151151
dispatch_key = dispatch_key_or_mode_or_transform
152152
assert (
153-
dispatch_key != DispatchKey.Python
153+
dispatch_key != torch._C.DispatchKey.Python
154154
), "Please register a mode for the torch._C.DispatchKey.Python key instead."
155155
assert isinstance(dispatch_key, torch._C.DispatchKey)
156156
assert dispatch_key not in self.table
@@ -162,10 +162,10 @@ def inner(fn):
162162
def dispatch(self, dispatch_key, *args, **kwargs):
163163
from torch.utils._python_dispatch import _get_current_dispatch_mode
164164

165-
if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
165+
if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode:
166166
return dispatch_functorch(self, args, kwargs)
167167

168-
if dispatch_key == DispatchKey.Python:
168+
if dispatch_key == torch._C.DispatchKey.Python:
169169
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
170170
curr_mode = _get_current_dispatch_mode()
171171
assert (
@@ -205,8 +205,9 @@ def inner(*args, **kwargs):
205205
- self.fallthrough_keys
206206
- DispatchKeySet(dispatch_key)
207207
)
208-
highest_key = all_keys_after_current_masked.highestPriorityTypeId()
209-
return self.dispatch(highest_key, *args, **kwargs)
208+
return self.dispatch(
209+
all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
210+
)
210211

211212
return inner
212213

@@ -225,9 +226,9 @@ def inner(*args, **kwargs):
225226
all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
226227
args, kwargs
227228
)
228-
highest_key = all_keys_after_current_masked.highestPriorityTypeId()
229-
230-
return self.dispatch(highest_key, *args, **kwargs)
229+
return self.dispatch(
230+
all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
231+
)
231232

232233
return inner
233234

0 commit comments

Comments
 (0)