Remove dependency on JAX #9494

@jeffhataws

🐛 Bug

In v2.8, JAX was added as a pip dependency in https://github.com/pytorch/xla/blob/master/setup.py#L121. As a result, JAX is installed in the user's environment even if they use neither torchax nor JAX. Additionally, we now see the error reported in #9243.

Also, current_accelerator() now returns device(type='jax') instead of device(type='cuda'), which changes some behavior unexpectedly. For example, in the parallel loader, pin_memory must now be set to False to work as before.
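To illustrate the pin_memory point, here is a minimal sketch of a guard that enables pinned memory only when the reported accelerator type is 'cuda'. The helper names `should_pin_memory` and `detect_accelerator_type` are hypothetical, and the sketch assumes a PyTorch build that exposes `torch.accelerator.current_accelerator()`; it is a possible user-side workaround, not something torch-xla provides.

```python
from typing import Optional


def should_pin_memory(device_type: Optional[str]) -> bool:
    """Hypothetical helper: pin host memory only for a CUDA accelerator.

    Any other reported type (for example 'jax', as described above) or no
    accelerator at all falls back to pin_memory=False.
    """
    return device_type == "cuda"


def detect_accelerator_type() -> Optional[str]:
    """Return the current accelerator's device type, or None if unavailable.

    Assumes torch.accelerator exists in the installed PyTorch; older builds
    without it are treated as having no accelerator.
    """
    try:
        import torch
    except ImportError:
        return None
    accelerator_api = getattr(torch, "accelerator", None)
    if accelerator_api is None:
        return None
    device = accelerator_api.current_accelerator()
    return None if device is None else device.type


if __name__ == "__main__":
    device_type = detect_accelerator_type()
    print("accelerator type:", device_type)
    print("pin_memory:", should_pin_memory(device_type))
```

The resulting boolean could then be passed as the `pin_memory` argument of a `torch.utils.data.DataLoader`, so the loader does not depend on which device type current_accelerator() happens to report.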

Finally, due to this dependency, torch-xla now also drops Python 3.10 support, even though Python 3.10 is supported until Oct 2026. It would be best to bring this support back for customers who still use Ubuntu 22.04.

To improve the customer experience, please make this dependency optional, perhaps as a dependency of torchax only.
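One way the request above could look in packaging terms is a setuptools extra, sketched below. Everything here is illustrative: the package name, the `torchax` extra name, and the placeholder requirement are assumptions, not the contents of the real torch-xla setup.py.

```python
# Hypothetical packaging sketch: JAX moves from install_requires to an
# optional extra. All names below are placeholders for illustration.


def build_setup_kwargs() -> dict:
    """Return keyword arguments that could be passed to setuptools.setup()."""
    return {
        "name": "example-package",  # placeholder, not the real package name
        # Installed unconditionally by `pip install example-package`.
        "install_requires": [
            "example-core-dep",  # placeholder for the non-JAX requirements
        ],
        # Installed only on request: `pip install "example-package[torchax]"`.
        "extras_require": {
            "torchax": ["jax", "jaxlib"],  # assumed extra name and contents
        },
    }


if __name__ == "__main__":
    kwargs = build_setup_kwargs()
    print("default:", kwargs["install_requires"])
    print("optional:", kwargs["extras_require"])
    # In a real setup.py this would be: setuptools.setup(**kwargs)
```

With a layout like this, a plain install would not pull in JAX, while users who want the JAX-backed functionality could opt in through the extra.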

To Reproduce

Install the torch-xla 2.8 release candidate, run pipdeptree to list the dependencies, and observe that jax is part of the torch-xla dependency tree.
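As an alternative to reading the pipdeptree output by eye, the check can be scripted with the standard library's `importlib.metadata`. This is a rough sketch: the helper names are hypothetical, the requirement parsing is deliberately simple, and it only inspects the direct requirements declared by the distribution, not the full transitive tree that pipdeptree shows.

```python
import re
from importlib import metadata
from typing import Iterable, Optional, Set


def _normalize(name: str) -> str:
    """Normalize a distribution name so 'Foo_Bar' and 'foo-bar' compare equal."""
    return re.sub(r"[-_.]+", "-", name).lower()


def unconditional_requirements(requirement_strings: Iterable[str]) -> Set[str]:
    """Return normalized names of requirements that are not gated on an extra."""
    names = set()
    for requirement in requirement_strings:
        spec, _, marker = requirement.partition(";")
        if "extra" in marker:
            continue  # only installed when the user asks for that extra
        match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", spec)
        if match:
            names.add(_normalize(match.group(1)))
    return names


def requires_by_default(dist_name: str, dep_name: str) -> Optional[bool]:
    """True/False if dist_name is installed, None if it is not installed."""
    try:
        declared = metadata.requires(dist_name) or []
    except metadata.PackageNotFoundError:
        return None
    return _normalize(dep_name) in unconditional_requirements(declared)


if __name__ == "__main__":
    # Per the report above, this would be expected to print True on 2.8.
    print(requires_by_default("torch-xla", "jax"))
```

A result of `None` means torch-xla is not installed in the environment being checked.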

Expected behavior

JAX is not installed by default with torch-xla.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: any
  • torch_xla version: 2.8

Additional context

Labels

  • install: PyTorch/XLA installation related issues.
  • usability: Bugs/features related to improving the usability of PyTorch/XLA.
