Conversation
|
I think what's happening with torch_xla.devices() has 4 devices but jax.devices() only have 1. One possibility is that jax[cuda] was not installed so jax.devices() returnts one device and that is the CPU device. Last time I tried to add the install and hit a different error. I am OK with disabling the test for CUDA until later. |
That is a great point! This made me realize that if we land this PR as-is, then not only will I think I'll split out the call_jax part of this PR into a separate one, and unfortunately that one can't be landed unless we fix the PJRT client sharing between PyTorch/XLA and JAX. |
8e682cd to
d83ffca
Compare
8b110a4 to
debf063
Compare
|
Draft for registering Jax contextual mesh in call_jax: #9043 (informational only) |
Beefed up the assume_pure tests and updated the docs to mention that mark_sharding is supported thanks to qihqi@'s #8989.
Also update yapf in the dev image to match CI.