
Add basic device APIs to the top-level torch_xla module #6571

Merged
will-cromar merged 6 commits into master from wcromar/torch-xla-device on Feb 21, 2024

Conversation

@will-cromar (Collaborator) commented Feb 20, 2024

Implemented basic device APIs from #6399.

  • Add device, devices, real_devices, and device_count as described in the RFC.
  • We already do a substantial amount of setup in __init__.py. Add torch_xla.py for public functions on the torch_xla module.
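The four APIs listed above could be exercised roughly as follows. This is a usage sketch only: the call names come from the PR description and RFC #6399, and the snippet is guarded so it degrades to a no-op when PyTorch/XLA is not installed.

```python
# Sketch of the new top-level device APIs from this PR. Guarded import so
# the example is harmless on machines without a PyTorch/XLA install.
try:
    import torch_xla
    HAVE_XLA = True
except ImportError:
    HAVE_XLA = False

if HAVE_XLA:
    dev = torch_xla.device()          # default XLA device for this process
    devs = torch_xla.devices()        # all addressable XLA devices
    count = torch_xla.device_count()  # number of addressable devices
    hw = torch_xla.real_devices()     # underlying hardware devices
    print(dev, count)
else:
    print("torch_xla not installed; nothing to demonstrate")
```

Note how these replace the older `xm.xla_device()`-style entry points mentioned in the follow-up items below.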

Follow up:

  • Update documentation to use new APIs.
  • Start deprecating or discouraging usage of old APIs like xm.xla_device and runtime.xla_device.

@will-cromar added the runtime and usability (Bugs/features related to improving the usability of PyTorch/XLA) labels on Feb 20, 2024
@will-cromar requested a review from JackCaoG on February 20, 2024 22:05
@will-cromar marked this pull request as ready for review on February 20, 2024 22:05

std::vector<std::string> xla_devices;
{
NoGilSection nogil;
Collaborator
Do you know why we release the GIL in this block? Is it for tpu v2/v3 where we allow multiple threads to do some runtime job such as GetXlaDevices(*devices)?

Collaborator Author @will-cromar commented Feb 21, 2024

I'm not sure to be honest. Maybe GetXlaDevices was a blocking call in XRT?

Collaborator

oh lol it is that torch_xla

Comment thread: test/test_devices.py, lines +11 to +13
def setUpClass():
xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'
Collaborator

hmm, it is OK for now but shouldn't we also test it on GPU and TPU?

Collaborator Author

This is sufficient IMO. We're really just testing the integration of this module with the runtime client, which has the same API regardless of the underlying device.

As we switch to using these functions by convention, they'll be exercised by almost every other test.

Comment thread: torch_xla/__init__.py
plugins.use_dynamic_plugins()
plugins.register_installed_plugins()

from .torch_xla import *
Collaborator

why is this needed?

Collaborator Author

This imports the contents of torch_xla.py into the torch_xla package's own module scope. Otherwise the functions would only be reachable as torch_xla.torch_xla.<name>; the star-import lifts them to torch_xla.<name>.
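The effect of the star-import can be reproduced with a toy package built in memory. The module names here (`toy_pkg`, `toy_pkg.toy_pkg`, `device_count`) are hypothetical stand-ins for torch_xla/__init__.py and torch_xla/torch_xla.py, not the real package:

```python
import sys
import types

# Toy submodule standing in for torch_xla/torch_xla.py: it defines the
# public functions.
inner = types.ModuleType("toy_pkg.toy_pkg")
exec("def device_count():\n    return 4", inner.__dict__)

# Toy package standing in for torch_xla/__init__.py.
pkg = types.ModuleType("toy_pkg")
sys.modules["toy_pkg"] = pkg
sys.modules["toy_pkg.toy_pkg"] = inner
pkg.toy_pkg = inner

# The star-import copies the submodule's public names into the package's
# own namespace -- the same thing `from .torch_xla import *` does.
exec("from toy_pkg.toy_pkg import *", pkg.__dict__)

print(pkg.device_count())          # -> 4, now a top-level name on the package
print(pkg.toy_pkg.device_count())  # -> 4, still reachable via the submodule
```

Without the star-import, only the second spelling would work, which is the `torch_xla.torch_xla.<name>` situation described above.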

@will-cromar will-cromar merged commit 0ec5b91 into master Feb 21, 2024
amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024

Labels

runtime usability Bugs/features related to improving the usability of PyTorch/XLA

3 participants