Define PJRT plugin interface in C++ #6360
 }

-std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
+std::shared_ptr<const PjRtPlugin> GetPjRtPlugin(
any particular reason for this optional -> shared_ptr change?
PjRtPlugin has to be a pointer or a reference now that we don't know the concrete type (whereas PluginEntry was just a value). My first thought was to make this an optional reference, but apparently C++ doesn't support that. An optional pointer would be unwieldy, because we'd have two layers of indirection/nullability (empty optional, and an optional holding nullptr). So I'm just returning shared_ptr here and letting nullptr represent the empty value.
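The trade-off described above can be illustrated with a minimal sketch. It assumes a simplified `PjRtPlugin` interface with a single virtual method; `FakePlugin`, `Registry`, and `RegisterPjRtPlugin` are hypothetical names used only for illustration and are not taken from the PR. In standards through C++23, `std::optional` cannot hold a reference type, so the lookup returns a `shared_ptr` and uses `nullptr` as the single "not found" state.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Simplified stand-in for the abstract plugin interface. The real PjRtPlugin
// in the PR may declare different members.
class PjRtPlugin {
 public:
  virtual ~PjRtPlugin() = default;
  virtual std::string library_path() const = 0;
};

// Hypothetical concrete plugin, used only for illustration.
class FakePlugin : public PjRtPlugin {
 public:
  std::string library_path() const override { return "/tmp/libfake_pjrt.so"; }
};

using PluginPtr = std::shared_ptr<const PjRtPlugin>;

// Hypothetical registry keyed by device type.
std::unordered_map<std::string, PluginPtr>& Registry() {
  static std::unordered_map<std::string, PluginPtr> registry;
  return registry;
}

void RegisterPjRtPlugin(const std::string& device_type, PluginPtr plugin) {
  Registry()[device_type] = std::move(plugin);
}

// A single layer of nullability: nullptr means "no plugin registered".
// Returning std::optional<PluginPtr> would add a second empty state
// (an engaged optional that holds nullptr).
PluginPtr GetPjRtPlugin(const std::string& device_type) {
  auto it = Registry().find(device_type);
  if (it == Registry().end()) {
    return nullptr;
  }
  return it->second;
}
```

A caller then needs only one check, for example `if (auto plugin = GetPjRtPlugin("FAKE")) { ... }`, before using the plugin through the base-class interface.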
torch_xla._XLAC._register_pjrt_plugin(name, device_plugin)


def register_installed_plugins():
pytorch/xla/plugins/cuda/README.md gives an example of how to use register_plugin. Under what circumstances do we use register_installed_plugins()? I see it's called when we import torch_xla. Does that mean that if we set XLA_REGISTER_INSTALLED_PLUGINS, users don't have to call register_plugin anymore?
I don't expect users to ever call either of these themselves. As long as plugin authors set the entrypoint correctly in their package, registration will happen in the background. Otherwise, plugin authors may add register_plugin to their module and run it on import (similar to how torch.distributed backend registration works).
jonb377 left a comment
LGTM, I like the C++ to Python interface approach.
Love C++ interface too.
Define PjRtPlugin in C++ so the functionality can be called there. In general, implementations should still be in Python, inheriting from DevicePlugin.
PyPjRtPlugin is the trampoline class that allows Python implementations of virtual functions.
library_path and client_create_options won't be invoked until client creation time, allowing the user to change settings after importing torch_xla.
TpuPlugin throws an EnvironmentError upon registration.
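The deferred-invocation behavior can be sketched in plain C++, without the Python trampoline. This is an illustration under assumptions, not the PR's implementation: `CountingPlugin`, `RegisterPlugin`, and `CreateClient` are hypothetical names, and the interface is reduced to `library_path` alone. Registration stores only the plugin pointer, and the virtual method is first called when a client is created, so a setting changed between the two steps is the one that takes effect.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Simplified stand-in for the abstract plugin interface.
class PjRtPlugin {
 public:
  virtual ~PjRtPlugin() = default;
  virtual std::string library_path() const = 0;
};

// Hypothetical plugin that reads its path from a user-controlled setting and
// counts how often library_path() is invoked.
class CountingPlugin : public PjRtPlugin {
 public:
  explicit CountingPlugin(const std::string* path_setting)
      : path_setting_(path_setting) {}

  std::string library_path() const override {
    ++calls_;
    return *path_setting_;
  }

  int calls() const { return calls_; }

 private:
  const std::string* path_setting_;  // not owned; must outlive the plugin
  mutable int calls_ = 0;
};

// Hypothetical single-slot registry.
std::shared_ptr<const PjRtPlugin>& RegisteredPlugin() {
  static std::shared_ptr<const PjRtPlugin> plugin;
  return plugin;
}

// Registration stores the pointer and calls no virtual methods.
void RegisterPlugin(std::shared_ptr<const PjRtPlugin> plugin) {
  RegisteredPlugin() = std::move(plugin);
}

// Hypothetical client creation: the first point at which library_path() runs.
// Returns the path that would be loaded.
std::string CreateClient() { return RegisteredPlugin()->library_path(); }
```

Because the path is resolved inside `CreateClient`, a plugin whose configuration is not yet valid at import time can still be registered without error in this sketch.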