Skip to content

Move mutable properties of env to thread local, misc changes#9501

Merged
qihqi merged 11 commits intomasterfrom
hanq_torchax
Jul 24, 2025
Merged

Move mutable properties of env to thread local, misc changes#9501
qihqi merged 11 commits intomasterfrom
hanq_torchax

Conversation

@qihqi
Copy link
Copy Markdown
Collaborator

@qihqi qihqi commented Jul 23, 2025

  1. refactored a bit of jax default device. Now it respects whatever user sets with jax.default_device
In [9]: with jax.default_device(jax.devices("cpu")[0]):
   ...:     a = torch.randn((2,2), device='jax')
   ...:

In [10]: a._elem.device
Out[10]: CpuDevice(id=0)

In [11]: with jax.default_device(jax.devices("cpu")[0]):
    ...:     a = torch.randn((2,2)).to(device='jax')
    ...:

In [12]: a
Out[12]:
Tensor(<class 'jaxlib._jax.ArrayImpl'> [[0.7494194  0.94113636]
 [0.18599068 0.8405661 ]])

In [13]: a._elem.device
Out[13]: CpuDevice(id=0)

In [14]: with jax.default_device(jax.devices("tpu")[0]):
    ...:     a = torch.randn((2,2)).to(device='jax')
    ...:

In [15]: a._elem.device
Out[15]: TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)

@qihqi qihqi requested a review from zzzwen July 23, 2025 22:36
@qihqi qihqi marked this pull request as ready for review July 23, 2025 22:36
Copy link
Copy Markdown
Collaborator

@zzzwen zzzwen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to remove the UT relating to jax_device as well

Comment thread torchax/torchax/tensor.py
Comment thread torchax/torchax/tensor.py Outdated
Comment thread torchax/torchax/tensor.py
@qihqi qihqi enabled auto-merge (squash) July 24, 2025 04:15
@qihqi qihqi merged commit 0a1594a into master Jul 24, 2025
23 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants