Make Tensor's __dlpack__ and __dlpack_device__ account for XLA.#128176
Make Tensor's __dlpack__ and __dlpack_device__ account for XLA.#128176vanbasten23 wants to merge 2 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128176
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 01c4a1f with merge base fa8c343 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@emcastillo, @rgommers, @Mulberry are you the right person to review this PR? Appreciate the suggestion. |
|
Hi @albanD @Skylion007 , could you help review this PR? Thanks! |
|
Gentle ping @albanD @Skylion007 |
|
I'm not sure what this environment variable is expected to do cc @JackCaoG |
There was a problem hiding this comment.
if torch_device_type is already xla, I think it is better to import torch_xla and check runtime time instead of checking the env var.
7c405b1 to
01c4a1f
Compare
|
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Taking over: #128176. In summary, this PR: - `__dlpack__`: Calls PyTorch/XLA `to_dlpack` function, if the tensor lives in an XLA:CUDA device - `__dlpack_device__`: Correctly maps PyTorch/XLA tensors to `kDLGPU`, if XLA:CUDA is being used The tests are introduced in pytorch/xla#7213. Pull Request resolved: #138470 Approved by: https://github.com/albanD Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
Taking over: #128176. In summary, this PR: - `__dlpack__`: Calls PyTorch/XLA `to_dlpack` function, if the tensor lives in an XLA:CUDA device - `__dlpack_device__`: Correctly maps PyTorch/XLA tensors to `kDLGPU`, if XLA:CUDA is being used The tests are introduced in pytorch/xla#7213. Pull Request resolved: #138470 Approved by: https://github.com/albanD Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
Fixes #ISSUE_NUMBER
XLA tensor is a Tensor. This PR
__dlpack_device__account for XLA device when XLA GPU is used.__dlpack__call the torch_xla version ofto_dlpack.A corresponding test is added in pytorch/xla#7213