[Train] JaxTrainer Implementation Tracking Issue #55162
Labels: community-backlog, enhancement, performance, train, triage, usability
Description
This issue tracks the implementation of a JaxTrainer to support JAX and SPMD workloads. Initial support targets SPMD with multi-host TPUs on Kubernetes.
Milestone 1: MVP of JaxTrainer with single-slice multi-host TPUs
- Add default TPU info to Ray node labels
- Add an API change to `ScalingConfig` to specify `topology` and `accelerator` arguments
- Add a `JaxTrainer` wrapper of `DataParallelTrainer` to Ray Train with SPMD scheduling support
- Extensively test JAX training workloads with multi-host TPUs on the Anyscale and KubeRay operators
- Add documentation and guides to the Ray Train docs
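To make the proposed API change concrete, here is a minimal sketch of how the new `ScalingConfig` arguments might look. This uses a plain dataclass rather than the real Ray Train class; the `topology` and `accelerator` field names come from this issue, while `ScalingConfigSketch` and the values `"4x4"` and `"TPU-V6E"` are illustrative assumptions, not the finalized API.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch of the proposed `ScalingConfig` extension.
# Field names `topology` and `accelerator` are taken from this issue;
# the final Ray Train signature may differ.
@dataclass
class ScalingConfigSketch:
    num_workers: int
    use_gpu: bool = False
    topology: Optional[str] = None     # TPU pod-slice topology, e.g. "4x4"
    accelerator: Optional[str] = None  # accelerator type label, e.g. "TPU-V6E"

# Intended usage: a single config object describes both the worker count
# and the TPU slice shape that SPMD scheduling should reserve.
cfg = ScalingConfigSketch(num_workers=4, topology="4x4", accelerator="TPU-V6E")
print(cfg.topology, cfg.accelerator)
```

The idea is that the trainer's SPMD scheduler can read `topology` to gang-schedule all hosts of one slice together, rather than placing workers independently.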
Milestone 2: Full support for TPU multi-slice
- Implemented in: [Train] Add TPU multi-slice support to JaxTrainer #58629
Milestone 3: Elastic training with multi-slice TPUs in Ray Train
- Detect alive TPU slices: [Core] Add TPU util to determine number of ready multi-host slices #61300
- Update elastic policy to scale for TPU slices: [Train] Update elastic policy to handle multi-host TPUs with JaxTrainer #61299
Use case
Full support for SPMD workloads orchestrated with Ray and Ray Train.