Add basic device APIs to the top-level torch_xla module. #6571

will-cromar merged 6 commits into master.
```cpp
std::vector<std::string> xla_devices;
{
  NoGilSection nogil;
```
Do you know why we release the GIL in this block? Is it for TPU v2/v3, where we allow multiple threads to do runtime work such as `GetXlaDevices(*devices)`?
I'm not sure to be honest. Maybe GetXlaDevices was a blocking call in XRT?
oh lol it is that torch_xla
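A quick Python sketch of why releasing the GIL around a blocking call matters: if the call holds the GIL, concurrent threads serialize behind it; releasing it (as `NoGilSection` does on the C++ side) lets the waits overlap. Here `time.sleep` (which releases the GIL while blocking) stands in for a hypothetical blocking runtime call.

```python
import threading
import time

# time.sleep releases the GIL while blocking, much like NoGilSection
# around a blocking runtime call: four threads overlap their 0.2s
# waits instead of running them back to back.
def blocking_call():
    time.sleep(0.2)

start = time.monotonic()
threads = [threading.Thread(target=blocking_call) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
elapsed = time.monotonic() - start
print(f"elapsed: {elapsed:.2f}s")  # roughly 0.2s, not 0.8s
```

If the underlying call instead held the GIL for its full duration, the four threads would take roughly 4x as long, which is why a blocking XRT-era call would have wanted this block.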
```python
def setUpClass():
    xr.set_device_type('CPU')
    os.environ['CPU_NUM_DEVICES'] = '4'
```
hmm, it is OK for now but shouldn't we also test it on GPU and TPU?
This is sufficient IMO. We're really just testing the integration of this module with the runtime client, which has the same API regardless of the underlying device.
As we switch to using these functions by convention, they'll be exercised by almost every other test.
```python
plugins.use_dynamic_plugins()
plugins.register_installed_plugins()
```
```python
from .torch_xla import *
```
This imports the contents of `torch_xla.py` into the `torch_xla/` package's module scope. Otherwise, the functions would be reachable only as `torch_xla.torch_xla.*`; this assigns them to `torch_xla.*`.
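The wildcard re-export pattern described above can be sketched with a throwaway in-memory package (the names `mypkg` and `impl` are hypothetical, for illustration only): a package `__init__.py` that does `from .impl import *` lifts `impl`'s public names into the package's own namespace.

```python
import sys
import types

# Hypothetical implementation module mypkg/impl.py, built in memory.
impl = types.ModuleType('mypkg.impl')
exec("def device_count():\n    return 4", impl.__dict__)

pkg = types.ModuleType('mypkg')
sys.modules['mypkg'] = pkg
sys.modules['mypkg.impl'] = impl

# Equivalent of `from .impl import *` inside mypkg/__init__.py:
# copy every public name from impl into the package namespace.
for name in dir(impl):
    if not name.startswith('_'):
        setattr(pkg, name, getattr(impl, name))

import mypkg
print(mypkg.device_count())  # -> 4, callable at top level, not only as mypkg.impl.device_count
```

Note that `import *` honors `__all__` when the source module defines it, so the set of re-exported names can be pinned down explicitly rather than being everything public.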
Implemented basic device APIs from #6399:

- Add `device`, `devices`, `real_devices`, and `device_count` as described in the RFC.
- Add `torch_xla.py` for public functions on the `torch_xla` module, imported into `__init__.py`.

Follow up: `xm.xla_device` and `runtime.xla_device`.