Add utility to get computed kernel in torch.library#158393
Add utility to get computed kernel in torch.library#158393mikaylagawarecki wants to merge 13 commits intogh/mikaylagawarecki/320/basefrom
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158393
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 172857e with merge base 34ec5ed ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
[ghstack-poisoned]
Related to #155330 [ghstack-poisoned]
Related to #155330 [ghstack-poisoned]
Related to #155330 cc albanD [ghstack-poisoned]
albanD
left a comment
There was a problem hiding this comment.
Sounds pretty good!
Only small questions!
torch/library.py
Outdated
| op = op._name | ||
|
|
||
| if isinstance(dispatch_key, str): | ||
| dispatch_key = torch._C.DispatchKey.__members__[dispatch_key] |
There was a problem hiding this comment.
What is the error you get when passing a wrong dispatch key here?
There was a problem hiding this comment.
oops, added proper error handling here
| auto [annotatedKernel, _] = | ||
| computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k); | ||
|
|
||
| return SafeKernelFunction(&annotatedKernel.kernel); |
There was a problem hiding this comment.
I think it would be nice to grab the debug string here and add that to the __repr__ we get from python?
There was a problem hiding this comment.
This gives something like
SafeKernelFunction(debug='registered at /data/users/mg1998/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp:2309')
Do you think that is meaningful enough or should we add more info
There was a problem hiding this comment.
Yes, it is 100% super useful. These error messages saved me a few times for multiple-registration errors!
And for python users, it should point to their code directly. Which is even better so they know which function this is!
|
|
||
| // List of tokens that need to be invalidated when this KernelFunction is | ||
| // destroyed | ||
| mutable std::vector<std::weak_ptr<KernelToken>> tokens_; |
There was a problem hiding this comment.
Also why weak_ptr and not shared_ptr?
There was a problem hiding this comment.
why mutable
I think this is necessary in order to make registerToken const, which was in turn needed to allow SafeKernelFunction to take in const KernelFunction*, removing this would necessitate const_cast-ing the annotatedKernel.kernel in getComputedKernelForDispatchKey, wdyt
why weak_ptr and not shared_ptr
What's the benefit of shared_ptr over weak_ptr here? if we use weak_ptr, the KernelToken dies with the SafeKernelFunction, which I think is what we want to achieve here (?)
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 cc albanD [ghstack-poisoned]
|
@pytorchbot merge -r |
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
|
Successfully rebased |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 Pull Request resolved: #158393 Approved by: https://github.com/albanD
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to #155330 Pull Request resolved: #158393 Approved by: https://github.com/albanD
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to pytorch#155330 Pull Request resolved: pytorch#158393 Approved by: https://github.com/albanD
Adds `OperatorEntry::getComputedKernelForDispatchKey` which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key - Specifically it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when the `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds this `KernelFunction` is removed from `kernels_`, which happens when the corresponding impl is deregistered). - `SafeKernelFunction` can be called via `callBoxed`, the validity of the token will be checked before this happens - `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend ia `torch.library.get_kernel` Related to pytorch#155330 Pull Request resolved: pytorch#158393 Approved by: https://github.com/albanD
Adds
OperatorEntry::getComputedKernelForDispatchKeywhich returns the KernelFunction corresponding toOperatorEntry.dispatchTable_[dispatch_ix]for a given dispatch keySafeKernelFunctionthat holds aKernelToken. ThisKernelTokenis registered to theKernelFunctioninOperatorEntry.kernels_and will be invalidated when theKernelFunctionis destructed (i.e. when theAnnotatedKernelthat holds thisKernelFunctionis removed fromkernels_, which happens when the corresponding impl is deregistered).SafeKernelFunctioncan be called viacallBoxed, the validity of the token will be checked before this happensSafeKernelFunctionis pybinded andgetComputedKernelForDispatchKeyis exposed to the frontend iatorch.library.get_kernelRelated to #155330
Stack from ghstack (oldest at bottom):
cc @albanD