[train][jax] Enable Jax trainer on GPU #58322

Merged
justinvyu merged 32 commits into ray-project:master from liulehui:jax-gpu
Nov 24, 2025
Conversation

@liulehui (Contributor) commented Oct 30, 2025

Description

  1. This PR adds multi-host GPU support for the Ray Train JaxTrainer.
  2. Following the JAX GPU distributed docs (https://docs.jax.dev/en/latest/multi_process.html#gpu-example): if `ScalingConfig.use_gpu == True`, add "cuda" as `JAX_PLATFORMS`.
  3. If cuda is the JAX platform, set `CUDA_VISIBLE_DEVICES` and initialize JAX distributed with https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize
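The per-worker setup described above can be sketched as follows. This is an illustrative sketch, not the PR's actual implementation: the function name `jax_gpu_worker_env` and its parameters are ours, and the returned values mirror the arguments that would later be passed to `jax.distributed.initialize`.

```python
def jax_gpu_worker_env(node_ip, port, num_processes, process_id, gpu_ids):
    """Sketch (hypothetical helper, not Ray's code) of the environment a
    Ray Train worker would need before importing jax when use_gpu is True."""
    return {
        # Force the CUDA backend, per the JAX multi-process GPU docs.
        "JAX_PLATFORMS": "cuda",
        # Restrict each worker to its assigned GPUs.
        "CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in gpu_ids),
        # Values a worker would pass to jax.distributed.initialize(
        #     coordinator_address=..., num_processes=..., process_id=...)
        "COORDINATOR_ADDRESS": f"{node_ip}:{port}",
        "NUM_PROCESSES": str(num_processes),
        "PROCESS_ID": str(process_id),
    }

env = jax_gpu_worker_env("10.0.0.1", 1234, 2, 0, [0, 1])
print(env["CUDA_VISIBLE_DEVICES"])
```

On a real GPU worker, these values would then feed `jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)` before any JAX computation runs.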

Additional information

  1. Tested with script here: https://gist.github.com/liulehui/b0b25065d48b730f2898b712aa92e06e

@liulehui liulehui added the "go (add ONLY when ready to merge, run all tests)" label Oct 30, 2025
@liulehui (Contributor, Author) commented:

JAX GPU image build on the Anyscale platform: https://gist.github.com/liulehui/bda2419e1b3245d40d8027053a8dd26c

Signed-off-by: Lehui Liu <lehui@anyscale.com>
@liulehui liulehui marked this pull request as ready for review November 22, 2025 01:58
@liulehui liulehui requested review from a team, matthewdeng and richardliaw as code owners November 22, 2025 01:58
@ray-gardener ray-gardener bot added the "train (Ray Train Related Issue)" label Nov 22, 2025
@justinvyu justinvyu enabled auto-merge (squash) November 24, 2025 18:56
@justinvyu justinvyu merged commit b88bcc1 into ray-project:master Nov 24, 2025
7 checks passed
justinvyu pushed a commit that referenced this pull request Nov 26, 2025
1. The Jax dependency is introduced in
#58322
2. The current test environment is for CUDA 12.1, which limits the jax
version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax tests if they run against Python 3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
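The version gate described in that commit can be sketched as below. This is an illustrative, stdlib-only sketch: the function name `should_skip_jax_tests` is ours, not Ray's, and a real test suite would typically express the same condition as a skip marker.

```python
import sys

def should_skip_jax_tests(version_info=None):
    """Hypothetical helper: skip JAX tests on Python 3.12+, since the
    CUDA 12.1 test image pins jax <= 0.4.14, which lacks 3.12 support."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= (3, 12)

print(should_skip_jax_tests())
```

In a pytest-based suite like Ray's, the equivalent gate would be a `pytest.mark.skipif(sys.version_info >= (3, 12), ...)` decorator on the affected tests.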
matthewdeng pushed a commit that referenced this pull request Feb 3, 2026
(#60593)

## Description

We added GPU (#58322) and multislice TPU (#58629) support for
JaxTrainer; this PR updates the corresponding docs.

## Additional information
1. Tested with `make develop && make local`

---------

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Labels

go (add ONLY when ready to merge, run all tests), train (Ray Train Related Issue)

3 participants