
Improve device auto-detection #7730

@will-cromar

Description


Device auto-detection (automatically setting PJRT_DEVICE when the user has not set it) only works for some functions.

torch_xla.device() works:

```
# python
>>> import torch_xla
>>> torch_xla.device()
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
device(type='xla', index=0)
```

torch_xla.real_devices() does not work:

```
# python
>>> import torch_xla
>>> torch_xla.real_devices()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspaces/ptxla/pytorch/xla/torch_xla/torch_xla.py", line 49, in real_devices
    return torch_xla._XLAC._xla_real_devices()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:31 : $PJRT_DEVICE is not set.
```

Auto-detection is only triggered by code paths that go through a function wrapped in requires_pjrt (or that call using_pjrt). We did this to accommodate XRT, which is no longer a concern. I believe we can trigger auto-detection on import, or at least more broadly, so that it covers all usage of our public API.
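To make the failure mode concrete, here is a toy model of the current behavior in plain Python. It does not import torch_xla and is not torch_xla's actual code: only the names requires_pjrt and using_pjrt come from this issue, while _auto_detect_device, _runtime_device_type, and the hard-coded "TPU" detection result are hypothetical stand-ins. It shows how detection becomes a side effect of whichever function happens to be called first.

```python
import functools
import os

# Toy model of the current behavior; NOT torch_xla's real implementation.
# The detection result is hard-coded for illustration (hypothetical).
_DETECTED_DEVICE = "TPU"


def _auto_detect_device():
    """Hypothetical stand-in for auto-detection: set PJRT_DEVICE if unset."""
    os.environ.setdefault("PJRT_DEVICE", _DETECTED_DEVICE)


def using_pjrt():
    # Detection only happens as a side effect of calling this check.
    _auto_detect_device()
    return True


def requires_pjrt(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        using_pjrt()  # auto-detection piggybacks on the PJRT check
        return fn(*args, **kwargs)
    return wrapper


def _runtime_device_type():
    # Models a runtime that reads PJRT_DEVICE directly and fails if it is unset.
    try:
        return os.environ["PJRT_DEVICE"]
    except KeyError:
        raise RuntimeError("$PJRT_DEVICE is not set.") from None


@requires_pjrt
def device():
    # Wrapped, so calling it triggers auto-detection first.
    return f"{_runtime_device_type()}:0"


def real_devices():
    # Not wrapped, so it fails unless something else already ran detection.
    return [f"{_runtime_device_type()}:0"]
```

With PJRT_DEVICE unset, calling real_devices() first raises RuntimeError, but calling device() first and then real_devices() succeeds, because device() set the variable as a side effect. The result depends on call order rather than on the function being called.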

Tasks:

  • Remove using_pjrt and requires_pjrt. These functions are both irrelevant now, and we only (ab)use them for their auto-detection side effect.
  • Remove this warning, which is printed during auto-detection: WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
  • Trigger device auto-detection somewhere else that makes sense (e.g. on package import). Keep in mind that, according to @JackCaoG, anything that references environment variables will cause a graph break in Dynamo.

Labels: usability (Bugs/features related to improving the usability of PyTorch/XLA)
