
[Train] Update elastic policy to handle multi-host TPUs with JaxTrainer#61299

Merged
matthewdeng merged 23 commits into ray-project:master from ryanaoleary:elastic-train-tpu
Mar 17, 2026
Conversation

@ryanaoleary ryanaoleary commented Feb 25, 2026

Description

This PR implements support for elastic training on TPUs using the JaxTrainer API and the elastic scaling policy.

Specifically, this PR utilizes a new TPU utility get_num_ready_tpu_slices to return the number of full, ready TPU slices in the RayCluster and then adjusts the _count_possible_workers calculation when running on TPUs to scale atomically by TPU slices. This PR also adds comprehensive unit tests and an e2e test for the new support.

I'll split the ray.util.tpu change into a separate PR, but left it in for now so that the tests can pass.
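For reference, the slice-atomic adjustment to the worker count described above can be sketched roughly as follows (a simplified illustration; the function name and signature here are assumptions for the example, not the merged implementation):

```python
def count_possible_tpu_workers(total_num_workers: int,
                               workers_per_slice: int,
                               num_ready_slices: int) -> int:
    """Floor the feasible worker count to whole TPU slices.

    total_num_workers: workers the autoscaler says could fit (raw resource math)
    workers_per_slice: hosts per slice, derived from the topology
    num_ready_slices: full, alive slices (as from get_num_ready_tpu_slices)
    """
    if workers_per_slice <= 0:
        # Invalid topology config; scale nothing rather than a partial slice.
        return 0
    num_complete_slices = total_num_workers // workers_per_slice
    # Cap by the slices that are actually alive and complete in the cluster.
    num_complete_slices = min(num_complete_slices, num_ready_slices)
    return num_complete_slices * workers_per_slice

# e.g. resources for 7 hosts, 4 hosts per slice, 2 ready slices -> only 4 workers
print(count_possible_tpu_workers(7, 4, 2))  # 4
```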

Related issues

Implements milestone 3 of #55162

Signed-off-by: ryanaoleary <ryanaoleary@google.com>
@ryanaoleary ryanaoleary requested review from a team as code owners February 25, 2026 01:09
@ryanaoleary
Contributor Author

Relevant output from e2e test:

Scaling policy waits until a full slice is available:

--------------------------------------------------------------------------------
[elapsed=3.1s] cluster_resources={'TPU': 4.0, 'CPU': 8.0, 'TPU-v6e-8-head': 1.0, 'accelerator_type:TPU-V6E': 1.0}
Successfully registered 4 TPUs in cluster.
--------------------------------------------------------------------------------

(TrainController pid=2730471) Requesting resources to fit the maximum number of workers: {'TPU': 4, 'CPU': 1, 'accelerator_type:TPU-V6E': 0.001} * 6
(TrainController pid=2730471) Detected ready resources for 0 workers in the cluster. Deciding NOT to start/restart training due to the number of workers falling below the minimum (min_workers=2).

Once a complete slice is available, controller cleanly upscales:

--------------------------------------------------------------------------------
[elapsed=22.5s] cluster_resources={'TPU': 16.0, 'CPU': 32.0, 'TPU-v6e-8-head': 2.0, 'accelerator_type:TPU-V6E': 4.0}
Successfully registered 16 TPUs in cluster.
--------------------------------------------------------------------------------

(TrainController pid=2730471) Detected ready resources for 4 workers in the cluster. Deciding to start/restart training with this worker group size.
(TrainController pid=2730471) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '2x4'...
(TrainController pid=2730471) Attempting to start training worker group of size 4 with the following resources: [{'TPU': 4, 'CPU': 1, 'accelerator_type:TPU-V6E': 0.001}] * 4
(TrainController pid=2730471) Started training worker group of size 4: 
(TrainController pid=2730471) - (ip=172.19.2.115, pid=2732044) world_rank=0, local_rank=0, node_rank=0
(TrainController pid=2730471) - (ip=172.19.2.115, pid=2731722) world_rank=1, local_rank=1, node_rank=0
(TrainController pid=2730471) - (ip=172.19.2.115, pid=2731914) world_rank=2, local_rank=2, node_rank=0
(TrainController pid=2730471) - (ip=172.19.2.115, pid=2732043) world_rank=3, local_rank=3, node_rank=0

When a TPU worker is preempted, the policy recovers and scales down a slice:

(TrainController pid=2730471) The actor died because its node has died. Node Id: dfac8b3572abfdebd0bbabee8afb53c5ad07dd88c2e020986682637a
(TrainController pid=2730471)    the actor's node was terminated expectedly: received SIGTERM
(TrainController pid=2730471) Detected ready resources for 2 workers in the cluster. Deciding to start/restart training with this worker group size.
(TrainController pid=2730471) Error during JAX distributed shutdown: The actor died unexpectedly before finishing this task.
(TrainController pid=2730471)    class_name: RayTrainWorker
(TrainController pid=2730471)    actor_id: 7572726017b152d60d7b0acf01000000
(TrainController pid=2730471)    pid: 2731914
(TrainController pid=2730471)    namespace: 0b45f7ad-7d93-4c55-b65f-ef0d39ebbd47
(TrainController pid=2730471)    ip: 172.19.2.115
(TrainController pid=2730471) The actor died because its node has died. Node Id: 5fa99c67d9c9483ad0070a17e05850044971db82b1dd72a9de1348f4
(TrainController pid=2730471)    the actor's node was terminated expectedly: received SIGTERM
(TrainController pid=2730471) Using SlicePlacementGroup utility to reserve 1 slice(s) with topology '2x4'...
(TrainController pid=2730471) Attempting to start training worker group of size 2 with the following resources: [{'TPU': 4, 'CPU': 1, 'accelerator_type:TPU-V6E': 0.001}] * 2
(TrainController pid=2730471) Started training worker group of size 2: 
(TrainController pid=2730471) - (ip=172.19.2.115, pid=2732536) world_rank=0, local_rank=0, node_rank=0
(TrainController pid=2730471) - (ip=172.19.2.115, pid=2732541) world_rank=1, local_rank=1, node_rank=0

Test completes:

--------------------------------------------------------------------------------
[elapsed=100.9s] cluster_resources={'TPU': 8.0, 'CPU': 16.0, 'TPU-v6e-8-head': 1.0, 'accelerator_type:TPU-V6E': 2.0}
Training finished with result: Result(metrics={'epoch': 30, 'world_size': 2, 'min_world_size': 2, 'max_world_size': 4}, checkpoint=Checkpoint(filesystem=local, path=/tmp/pytest-of-ryanaoleary/pytest-54/test_elastic_training_tpu0/ray_train_run-2026-02-25_00-59-29/checkpoint-epoch=30), error=None, path='/tmp/pytest-of-ryanaoleary/pytest-54/test_elastic_training_tpu0/ray_train_run-2026-02-25_00-59-29', metrics_dataframe=   epoch  world_size  min_world_size  max_world_size
0     29           2               2               4
1     30           2               2               4, best_checkpoints=[(Checkpoint(filesystem=local, path=/tmp/pytest-of-ryanaoleary/pytest-54/test_elastic_training_tpu0/ray_train_run-2026-02-25_00-59-29/checkpoint-epoch=29), {'epoch': 29, 'world_size': 2, 'min_world_size': 2, 'max_world_size': 4}), (Checkpoint(filesystem=local, path=/tmp/pytest-of-ryanaoleary/pytest-54/test_elastic_training_tpu0/ray_train_run-2026-02-25_00-59-29/checkpoint-epoch=30), {'epoch': 30, 'world_size': 2, 'min_world_size': 2, 'max_world_size': 4})], _storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7fa9d4173e70>)
--------------------------------------------------------------------------------

@ryanaoleary
Contributor Author

cc: @liulehui @matthewdeng

@ryanaoleary
Contributor Author

cc: @siyuanfoundation for help testing, with a guide showcasing Orbax checkpointing with this new support (i.e. what was being implemented in #60759)

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for elastic training on multi-host TPUs with the JaxTrainer. The core changes involve a new utility get_num_ready_tpu_slices to accurately count available TPU slices and modifications to the elastic scaling policy to scale atomically based on these slices. The implementation appears robust and is well-supported by comprehensive unit tests and a new end-to-end test that simulates various scaling scenarios. The changes are clear and address the requirements for elastic TPU training effectively. I have one suggestion to improve the e2e test's implementation to avoid modifying global state.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
@ryanaoleary
Contributor Author

A future PR may need to extend support for requesting labels to the AutoscalingCoordinator: https://github.com/ray-project/ray/blob/master/python/ray/data/_internal/cluster_autoscaler/default_autoscaling_coordinator.py. We currently require the TPU utility to verify that the available nodes have the required topology labels and are available, but if the AutoscalingCoordinator natively understood labels (currently it only checks resource counts), we could simplify the TPU logic by passing the desired labels when the elastic policy makes scaling decisions.

Contributor

@liulehui liulehui left a comment


thank you!

total_num_workers = min(int(total_num_workers), self.scaling_config.max_workers)

# Multi-host TPUs are scheduled atomically in interconnected slices defined by a topology.
# Floor the total available workers to the nearest multiple of the slice size.
Contributor


I would prefer to throw an error here instead of flooring it. Can we validate min_workers/max_workers in _validate_tpu_config: https://github.com/ray-project/ray/blob/master/python/ray/train/v2/api/config.py#L170

Contributor Author


I think we are already validating min_workers and max_workers here:

if workers_per_slice > 0 and max_workers % workers_per_slice != 0:

I think we need the flooring logic here because otherwise the policy might try to upscale when we have an uneven number of TPUs (i.e. min_workers is 4 and resources for 7 TPU nodes are available -> we want the policy to upscale 4 workers and then wait until resources for 1 additional worker are available to scale up to 8). Without the flooring logic, even if min_workers and max_workers are validated correctly, the policy will try to upscale whenever resources for more than min_workers are available, even if those resources do not form a full TPU slice.
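A quick numeric check of that scenario (illustrative only; the variable names are made up for the example):

```python
min_workers = 4
workers_per_slice = 4      # one Ray Train worker per TPU host in a slice
available_workers = 7      # autoscaler sees resources for 7 hosts

# Without flooring, 7 > min_workers would trigger an upscale onto a
# fractured slice. Flooring waits for whole slices instead.
floored = (available_workers // workers_per_slice) * workers_per_slice
print(floored)  # 4 -> start 4 workers now, go to 8 once a full second slice is ready
```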

and self.scaling_config.topology
and self.scaling_config.accelerator_type
):
from ray.util.tpu import get_num_ready_tpu_slices, get_tpu_worker_resources
Contributor


It seems that for calculating possible TPU workers we are not using allocated_resources; should we maybe factor this out into a _count_possible_tpu_workers instead?

Contributor


On second thought, I think we still want to use allocated_resources. It is possible that in a cluster the "ALIVE" TPU nodes are used by both data loading and training, which is the point of the autoscaling_coordinator: to balance/prioritize resources within the cluster. So it might not be correct for training to get num_ready_tpu_slices by checking the alive nodes within the Ray cluster.

Contributor Author

@ryanaoleary ryanaoleary Feb 26, 2026


Yeah I think we are using allocated_resources because we take the count calculated using that value and divide it to get the number of available TPU slices:

num_complete_slices = total_num_workers // workers_per_slice

I think you're right that num_ready_tpu_slices is only returning the alive slices, which doesn't necessarily mean they're available, since they could be used for data loading, etc., but we're already accounting for this by calculating:

num_complete_slices = min(num_complete_slices, num_ready_slices)

So we use the allocated_resources count to determine how many "free" TPU slices we have, relying on the autoscaler as the source of truth, and then cap that value by the number of "Alive" TPU slices that are labeled as part of the same ray.io/tpu-slice-name slice. I think I could make the variable names clearer to describe what's happening.

Contributor


suppose 4 slices (slice 1-4) are allocated to task1, and 2 slices (slice 5-6) allocated to task2, each slice has 4 hosts; then suppose slice 1-4 each have 1 dead node.

num_complete_slices = total_num_workers // workers_per_slice = 4*3/4=3

num_ready_slices will be the 2 healthy slices of task2.
In this case, wouldn't num_complete_slices = min(num_complete_slices, num_ready_slices)=2 but actually it should be 0?
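Working that scenario through as arithmetic (illustrative only; the names are made up for the example):

```python
workers_per_slice = 4
# Slices 1-4 (task1) each lost one host; slices 5-6 (task2) are healthy but busy.
free_hosts_from_fractured_slices = 4 * 3   # 12 hosts, none forming a full free slice
num_complete_slices = free_hosts_from_fractured_slices // workers_per_slice  # 3
num_ready_slices = 2                        # the two healthy (but allocated) slices

overcounted = min(num_complete_slices, num_ready_slices)
print(overcounted)  # 2, even though no full slice is actually free -> should be 0
```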

Contributor Author


Yeah I think you're correct that it should evaluate to 0. I think the current code works in practice because TPU pods with KubeRay will be scaled down atomically, so the surviving orphaned nodes in Slices 1-4 will be rapidly terminated. Even during the brief window before that atomic teardown finishes, GCS enforces the bundle_label_selector set in the SlicePlacementGroups it creates for the 2 slices, so the placement request remains in a PENDING state even when there are fractured slices until the whole slice is healthy and available.

Even though the current solution might work, I agree it's hacky, and it'd be better if the utility returned slices that are both complete and available (rather than just feasible). I didn't do this previously because I was worried about race conditions or inconsistencies in resource availability if the same workers were temporarily used for other tasks like data loading (i.e. we'd still want the PG to be created, but just remain pending until the other task was done and it could be scheduled).

Contributor Author

@ryanaoleary ryanaoleary Mar 7, 2026


Done in d266c8c, the helper now relies on the State API to get node resource usage.

Contributor


Thanks! I think it is pretty robust now.

@ryanaoleary ryanaoleary requested a review from liulehui February 26, 2026 21:55
ryanaoleary and others added 5 commits February 27, 2026 01:09

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

@siyuanfoundation
Contributor

/lgtm

)

if failure_decision == FailureDecision.RETRY:
if self._worker_group:
Contributor Author


@liulehui this is what ended up fixing the issue we talked about offline. Without this line to shutdown the worker group, the elastic policy would stall on the resize decision after killing 1 node. This was happening because before it releases the placement group and shuts down the worker group, the policy calculates the number of possible workers for the resize decision - but since our TPU logic checks for available slices and the slices are still held by the worker group, it returned resources for 0 workers.

Contributor


ah I see.

In the current state transition flow, I think it is better for RETRY to stay declarative, i.e. it only says "we are going to restart".

Meanwhile, I think it should be fine to put the shutdown_worker_group call in or before _make_and_handle_scaling_decision_for_non_running_worker_group(), before the scaling policy makes its scaling decision.

Since the worker group is non-running already, we should assume the old worker group is gone.

cc @justinvyu @matthewdeng WDYT?

Contributor Author


That makes sense to me - updated it to this in 53be039

ryanaoleary and others added 3 commits March 12, 2026 06:56
@ryanaoleary ryanaoleary requested a review from liulehui March 12, 2026 06:59
edoakes pushed a commit that referenced this pull request Mar 12, 2026
…61300)

## Description
This PR adds a util to check for the number of alive, complete TPU
slices in a RayCluster. This PR also adds better test coverage.

This utility is used in the Ray Train elastic policy to cap the number
of workers that can be scaled by the AutoscalingCoordinator.

## Related issues
#55162

Related PR: #61299

---------

Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Co-authored-by: Mengjin Yan <mengjinyan3@gmail.com>
ryanaoleary and others added 2 commits March 13, 2026 22:49
Contributor

@liulehui liulehui left a comment


tyvm

@liulehui liulehui added the go label (add ONLY when ready to merge, run all tests) Mar 16, 2026
Comment on lines +168 to +169
# Fallback to the raw calculation if the strict math fails
return total_num_workers
Contributor


Do we ever want this to happen? If workers_per_slice == 0 what would the expected behavior be here?

Contributor Author


workers_per_slice == 0 should be impossible, since it'd indicate an invalid topology config (i.e. there is no TPU configuration where the resources required by a worker are greater than the resources on a slice, which is how 0 would be returned here). We validate this in the v2 scaling config and in the tpu util that's called, so this case should never be reached.

The expected behavior would be that we just return 0 in this function since we don't have the resources to scale any workers. Rather than a nested if statement, I'll change it to an early return of 0.
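That refactor might look roughly like this (a sketch under assumed names, not the merged code), showing why the early return of 0 is safer than the raw fallback:

```python
# Before (assumed shape): nested condition with a raw fallback that could
# schedule a partial slice if workers_per_slice were ever 0.
def count_before(total_num_workers, workers_per_slice, num_ready_slices):
    if workers_per_slice > 0:
        n = min(total_num_workers // workers_per_slice, num_ready_slices)
        return n * workers_per_slice
    # Fallback to the raw calculation if the strict math fails
    return total_num_workers

# After: early return of 0, so an invalid topology never schedules workers.
def count_after(total_num_workers, workers_per_slice, num_ready_slices):
    if workers_per_slice <= 0:
        return 0
    n = min(total_num_workers // workers_per_slice, num_ready_slices)
    return n * workers_per_slice
```

Both behave identically for valid configs; they diverge only in the (supposedly impossible) workers_per_slice == 0 case.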

Comment on lines +158 to +160
# The number of workers scaled should be a multiple of the number of
# workers that fit on a TPU slice.
return num_available_slices * workers_per_slice
Contributor


If num_available_slices == 0 shouldn't we just let this also be 0?

Contributor Author


Yeah I think so, that should already be the case because if we get to this return and num_available_slices == 0 (as a result of num_available_slices = total_num_workers // workers_per_slice == 0 or if the get_num_ready_tpu_slices call fails) we'll just return 0.

Though I think you're right that below we shouldn't # Fallback to the raw calculation if the strict math fails, and should instead just return 0. Otherwise, if the detection math fails for some reason, we could end up scheduling a partial slice, which will fail since TPU requires SPMD.

@ryanaoleary ryanaoleary requested a review from matthewdeng March 17, 2026 03:44
@matthewdeng matthewdeng enabled auto-merge (squash) March 17, 2026 03:49
@matthewdeng matthewdeng merged commit 4397fcb into ray-project:master Mar 17, 2026
7 checks passed
rayhhome pushed a commit to rayhhome/ray that referenced this pull request Mar 17, 2026
…er (ray-project#61299)

pedrojeronim0 pushed a commit to pedrojeronim0/ray that referenced this pull request Mar 23, 2026
…er (ray-project#61299)


Labels

community-contribution (Contributed by the community), go (add ONLY when ready to merge, run all tests)
