Device auto-detection only works for some functions.
`torch_xla.device()` works:

```python
>>> import torch_xla
>>> torch_xla.device()
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
device(type='xla', index=0)
```
`torch_xla.real_devices()` does not work:

```python
>>> import torch_xla
>>> torch_xla.real_devices()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspaces/ptxla/pytorch/xla/torch_xla/torch_xla.py", line 49, in real_devices
    return torch_xla._XLAC._xla_real_devices()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:31 : $PJRT_DEVICE is not set.
```
Only functions that pass through a function wrapped in `requires_pjrt` (or that call `using_pjrt`) trigger auto-detection. We did this to accommodate XRT, which is no longer a concern. I believe we can trigger auto-detection on import, or at least more broadly, so that it covers all public API usage.
Tasks:
- `using_pjrt` and `requires_pjrt`: these functions are both irrelevant now, and we only (ab)use them for a side effect (triggering device auto-detection).