Dynamic PJRT plugin registration API #5644
Conversation
Heads up @jzhoulon @aws-kingrj, I'm working on a new way for external packages to register PJRT plugins with `torch_xla`. When this API is finalized, we can move plugin registration (something like …)
Leaving this as draft for now until I rebase after #5677, but this PR is largely ready for comments.
```python
assert len(xm.get_xla_supported_devices('TPU')) > 0
```
```python
def test_dynamic_plugin_api(self):
  with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
```
I guess the difference between test_dynamic_plugin_api and test_spawn is that the former tests single-processing and the latter tests multi-processing?
Yeah, I wrote test_dynamic_plugin_api before the other. I'll change the name to something like test_single_process to be clearer.
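The pattern under discussion can be sketched like this. This is a minimal stand-in, not the PR's actual test: `check_devices` here is a placeholder for the real torch_xla device assertion.

```python
import concurrent.futures


def check_devices():
    # Stand-in for the real assertion on xm.get_xla_supported_devices('TPU');
    # returns a placeholder device list instead of touching a runtime.
    return ['TPU:0']


def run_single_process_check():
    # max_workers=1 runs the check in a single fresh subprocess, so any
    # runtime initialization is isolated from the parent test process.
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        return executor.submit(check_devices).result()
```

A spawn-based test would instead launch several such workers at once; this single-worker version only checks that the plugin initializes in one isolated process.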
```python
def register_plugin(name: str, device_plugin: DevicePlugin):
  _plugin_registry[name.upper()] = device_plugin
  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path())
```
I wonder what library_path we should use for GPU. IIUC, GPU doesn't involve the libtpu library.
None right now. GPU support is statically linked in. When that moves to a plugin (say libsegpu.so), it will be the path to that binary.
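A future GPU plugin along those lines might look roughly like this. The `DevicePlugin` base class is modeled on this PR's interface, but the `GpuPlugin` class and the path it returns are purely illustrative assumptions:

```python
import os


class DevicePlugin:
    """Minimal stand-in for the plugin interface in this PR."""

    def library_path(self) -> str:
        raise NotImplementedError


class GpuPlugin(DevicePlugin):
    """Hypothetical plugin for a dynamically loaded GPU PJRT binary."""

    def library_path(self) -> str:
        # Illustrative only: once GPU support is no longer statically
        # linked, this would point at the plugin shared object (say,
        # libsegpu.so) shipped inside the package.
        return os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'lib', 'libsegpu.so')
```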
```python
@staticmethod
def _assert_tpus_exist(index=0):
  del index
```
index is required for spawn, but we don't need it. I just explicitly delete it to mark it as unused.
```python
if runtime.device_type() == 'TPU':
  if plugins.using_dynamic_plugins():
    plugins.default().configure_single_process()
```
Should we throw a warning or something when people configure PJRT_DEVICE while also registering the plugin in code?
You still select the device type with PJRT_DEVICE. Plugins will just let you register new device types when we clean up all of the hardcoded strings.
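In other words, something like the following division of labor. This is a sketch with assumed names, not the PR's exact internals: registration adds entries to a name-to-plugin table, while `PJRT_DEVICE` only selects among the registered types.

```python
import os

_plugin_registry = {}


def register_plugin(name, plugin):
    # Registration only makes a device type *available*; it doesn't select it.
    _plugin_registry[name.upper()] = plugin


def default_plugin():
    # Selection still happens via PJRT_DEVICE, exactly as before.
    device_type = os.environ.get('PJRT_DEVICE', '').upper()
    if device_type not in _plugin_registry:
        raise KeyError(f'no plugin registered for PJRT_DEVICE={device_type!r}')
    return _plugin_registry[device_type]
```

So registering a plugin and setting `PJRT_DEVICE` are complementary rather than conflicting, which is why no warning is strictly needed.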
```cpp
return std::stoi(device.substr(pos + 1));
}
```
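The C++ fragment above extracts a device ordinal; in Python terms the same logic is (a sketch, assuming device strings of the form `'TPU:1'`):

```python
def parse_ordinal(device: str) -> int:
    # Mirror of the C++ snippet: take everything after the last ':' of a
    # device string such as 'TPU:1' and parse it as the integer ordinal.
    pos = device.rfind(':')
    return int(device[pos + 1:])
```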
```cpp
std::unordered_map<std::string, std::string> pjrt_plugins_;
```
I thought the user can only register one plugin? What's the use case of registering multiple?
We'll register TPU and GPU as default options, and then other packages will add plugins on top of those. JAX is also using Python entry points to register available plugins automatically, which we may also want to do.
```python
import torch_xla.runtime as xr
from torch_xla._internal import tpu
```
```python
plugins.register_plugin('TPU', tpu.TpuPlugin())
```
Are you going to put this in our init file eventually?
Yeah. I'm avoiding any changes to the default behavior while this is WIP.
First pass at implementing a common API for device plugins. The eventual goal is to remove any cases where we have to hard-code the device type in our build, allowing truly dynamic plugins through the PJRT plugin API.
- `DevicePlugin` API, including a sample implementation for TPU
- Gated behind `PJRT_DYNAMIC_PLUGINS=1` or `plugins.use_dynamic_plugins`
- Register a plugin with `plugin.register_plugin` and enable plugins; then you can use the same device by name by setting `PJRT_DEVICE` (see the integration test in this PR for an example)

Future work:

- `XlaDeviceType`, so plugins don't have to register their device strings in this repository