[train][jax] Enable Jax trainer on GPU #58322
Merged
justinvyu merged 32 commits into ray-project:master on Nov 24, 2025
JAX GPU image build on the Anyscale platform: https://gist.github.com/liulehui/bda2419e1b3245d40d8027053a8dd26c
Signed-off-by: Lehui Liu <lehui@anyscale.com>
richardliaw
approved these changes
Nov 24, 2025
justinvyu
pushed a commit
that referenced
this pull request
Nov 26, 2025
1. Jax dependency is introduced in #58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
Signed-off-by: Lehui Liu <lehui@anyscale.com>
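The version gate described in this commit message could be sketched as follows. This is a minimal illustration, not the code from the commit: the helper name `should_skip_jax_tests` and the test class are hypothetical, and it uses the standard-library `unittest.skipIf` rather than whatever skip mechanism the Ray test suite actually uses.

```python
import sys
import unittest


def should_skip_jax_tests(version_info=sys.version_info) -> bool:
    """Hypothetical helper: True when the interpreter is Python 3.12 or newer.

    The commit message states that the jax versions usable in the CUDA 12.1
    test environment do not support Python 3.12, so the tests are skipped there.
    """
    return tuple(version_info[:2]) >= (3, 12)


class JaxSmokeTest(unittest.TestCase):
    # Illustrative test only; the real test names are not shown in the source.
    @unittest.skipIf(
        should_skip_jax_tests(),
        "jax in the CUDA 12.1 test environment does not support Python 3.12+",
    )
    def test_import_jax(self):
        import jax  # imported inside the test so collection works without jax

        self.assertTrue(hasattr(jax, "devices"))
```

Importing jax inside the test body keeps the module importable on interpreters where jax is not installed, which is the situation the skip is meant to cover.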
KaisennHu
pushed a commit
to KaisennHu/ray
that referenced
this pull request
Nov 26, 2025
1. Jax dependency is introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
Signed-off-by: Lehui Liu <lehui@anyscale.com>
aslonnie
pushed a commit
that referenced
this pull request
Nov 26, 2025
Jax dependency is introduced in #58322. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14. jax <= 0.4.14 does not support Python 3.12. Skip the jax test if it runs against Python 3.12+.
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
ykdojo
pushed a commit
to ykdojo/ray
that referenced
this pull request
Nov 27, 2025
1. This PR adds multi-host GPU support for the Ray Train JaxTrainer.
2. Following the Jax [GPU distributed doc](https://docs.jax.dev/en/latest/multi_process.html#gpu-example): if `ScalingConfig.use_gpu == True`, we add "cuda" to JAX_PLATFORMS.
3. If cuda is the jax platform, add CUDA_VISIBLE_DEVICES and initialize jax distributed with https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: YK <1811651+ykdojo@users.noreply.github.com>
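The three steps in this commit message can be sketched roughly as below. This is an illustration under stated assumptions, not the code merged in the PR: `jax_gpu_env` and `init_jax_distributed` are hypothetical helper names, and the way existing `JAX_PLATFORMS` entries are merged is a guess. The only real library call is `jax.distributed.initialize`, with its documented `coordinator_address`, `num_processes`, and `process_id` parameters.

```python
from typing import Dict, Sequence


def jax_gpu_env(
    use_gpu: bool, gpu_ids: Sequence[int], current_platforms: str = ""
) -> Dict[str, str]:
    """Hypothetical helper: build the per-worker environment variables.

    Adds "cuda" to JAX_PLATFORMS when use_gpu is True, and sets
    CUDA_VISIBLE_DEVICES whenever cuda is among the platforms.
    """
    platforms = [p.strip().lower() for p in current_platforms.split(",") if p.strip()]
    if use_gpu and "cuda" not in platforms:
        platforms.append("cuda")

    env: Dict[str, str] = {}
    if platforms:
        env["JAX_PLATFORMS"] = ",".join(platforms)
    if "cuda" in platforms:
        env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


def init_jax_distributed(
    coordinator_address: str, num_processes: int, process_id: int
) -> None:
    """Hypothetical wrapper around jax.distributed.initialize for one worker."""
    import jax  # imported lazily so jax_gpu_env works without jax installed

    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
```

In a real worker the returned variables would need to be applied to `os.environ` before jax is first imported, since jax reads its platform configuration at import or backend-initialization time.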
SheldonTsen
pushed a commit
to SheldonTsen/ray
that referenced
this pull request
Dec 1, 2025
1. This PR adds multi-host GPU support for the Ray Train JaxTrainer.
2. Following the Jax [GPU distributed doc](https://docs.jax.dev/en/latest/multi_process.html#gpu-example): if `ScalingConfig.use_gpu == True`, we add "cuda" to JAX_PLATFORMS.
3. If cuda is the jax platform, add CUDA_VISIBLE_DEVICES and initialize jax distributed with https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
SheldonTsen
pushed a commit
to SheldonTsen/ray
that referenced
this pull request
Dec 1, 2025
1. Jax dependency is introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
Signed-off-by: Lehui Liu <lehui@anyscale.com>
SheldonTsen
pushed a commit
to SheldonTsen/ray
that referenced
this pull request
Dec 1, 2025
Jax dependency is introduced in ray-project#58322. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14. jax <= 0.4.14 does not support Python 3.12. Skip the jax test if it runs against Python 3.12+.
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
matthewdeng
pushed a commit
that referenced
this pull request
Jan 13, 2026
## Description
1. Jax dependency is introduced in #58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
rushikeshadhav
pushed a commit
to rushikeshadhav/ray
that referenced
this pull request
Jan 14, 2026
## Description
1. Jax dependency is introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
jeffery4011
pushed a commit
to jeffery4011/ray
that referenced
this pull request
Jan 20, 2026
## Description
1. Jax dependency is introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: jeffery4011 <jefferyshen1015@gmail.com>
ryanaoleary
pushed a commit
to ryanaoleary/ray
that referenced
this pull request
Feb 3, 2026
## Description
1. Jax dependency is introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
rayhhome
pushed a commit
to rayhhome/ray
that referenced
this pull request
Feb 4, 2026
ray-project#60593)
## Description
We added GPU (ray-project#58322) and multislice TPU (ray-project#58629) support for JaxTrainer; this PR updates the corresponding docs.
## Additional information
1. Tested with `make develop && make local`.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: Sirui Huang <ray.huang@anyscale.com>
elliot-barn
pushed a commit
that referenced
this pull request
Feb 9, 2026
#60593)
## Description
We added GPU (#58322) and multislice TPU (#58629) support for JaxTrainer; this PR updates the corresponding docs.
## Additional information
1. Tested with `make develop && make local`.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
ans9868
pushed a commit
to ans9868/ray
that referenced
this pull request
Feb 18, 2026
ray-project#60593)
## Description
We added GPU (ray-project#58322) and multislice TPU (ray-project#58629) support for JaxTrainer; this PR updates the corresponding docs.
## Additional information
1. Tested with `make develop && make local`.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: Adel Nour <ans9868@nyu.edu>
peterxcli
pushed a commit
to peterxcli/ray
that referenced
this pull request
Feb 25, 2026
1. This PR adds multi-host GPU support for the Ray Train JaxTrainer.
2. Following the Jax [GPU distributed doc](https://docs.jax.dev/en/latest/multi_process.html#gpu-example): if `ScalingConfig.use_gpu == True`, we add "cuda" to JAX_PLATFORMS.
3. If cuda is the jax platform, add CUDA_VISIBLE_DEVICES and initialize jax distributed with https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: peterxcli <peterxcli@gmail.com>
peterxcli
pushed a commit
to peterxcli/ray
that referenced
this pull request
Feb 25, 2026
1. Jax dependency is introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: peterxcli <peterxcli@gmail.com>
peterxcli
pushed a commit
to peterxcli/ray
that referenced
this pull request
Feb 25, 2026
Jax dependency is introduced in ray-project#58322. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14. jax <= 0.4.14 does not support Python 3.12. Skip the jax test if it runs against Python 3.12+.
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
Signed-off-by: peterxcli <peterxcli@gmail.com>
peterxcli
pushed a commit
to peterxcli/ray
that referenced
this pull request
Feb 25, 2026
## Description
1. Jax dependency is introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax test if it runs against Python 3.12+.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: peterxcli <peterxcli@gmail.com>
peterxcli
pushed a commit
to peterxcli/ray
that referenced
this pull request
Feb 25, 2026
ray-project#60593)
## Description
We added GPU (ray-project#58322) and multislice TPU (ray-project#58629) support for JaxTrainer; this PR updates the corresponding docs.
## Additional information
1. Tested with `make develop && make local`.
---------
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: peterxcli <peterxcli@gmail.com>
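Description
1. This PR adds multi-host GPU support for the Ray Train JaxTrainer.
2. Following the Jax [GPU distributed doc](https://docs.jax.dev/en/latest/multi_process.html#gpu-example): if `ScalingConfig.use_gpu == True`, we add "cuda" to JAX_PLATFORMS.
3. If cuda is the jax platform, add CUDA_VISIBLE_DEVICES and initialize jax distributed with https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html#jax.distributed.initialize
Related issues
Additional information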