[Train] Update elastic policy to handle multi-host TPUs with JaxTrainer #61299
matthewdeng merged 23 commits into ray-project:master
Conversation
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Relevant output from the e2e test:
- The scaling policy waits until a full slice is available.
- Once a complete slice is available, the controller cleanly upscales.
- When a TPU worker is preempted, the policy recovers and scales down a slice.
- The test completes.

cc: @siyuanfoundation for help testing with a guide showcasing Orbax checkpointing with this new support (i.e. what was being implemented in #60759)
Code Review
This pull request introduces support for elastic training on multi-host TPUs with the JaxTrainer. The core changes involve a new utility get_num_ready_tpu_slices to accurately count available TPU slices and modifications to the elastic scaling policy to scale atomically based on these slices. The implementation appears robust and is well-supported by comprehensive unit tests and a new end-to-end test that simulates various scaling scenarios. The changes are clear and address the requirements for elastic TPU training effectively. I have one suggestion to improve the e2e test's implementation to avoid modifying global state.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
A future PR may need to extend support for requesting labels to the
```python
total_num_workers = min(int(total_num_workers), self.scaling_config.max_workers)

# Multi-host TPUs are scheduled atomically in interconnected slices defined by a topology.
# Floor the total available workers to the nearest multiple of the slice size.
```
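The flooring described in this comment can be sketched as a small helper. The name and signature here are hypothetical, not the actual Ray implementation:

```python
def floor_to_slice_multiple(total_num_workers: int, workers_per_slice: int) -> int:
    """Floor a worker count to the nearest multiple of the TPU slice size."""
    if workers_per_slice <= 0:
        # Invalid topology: nothing can be scheduled atomically.
        return 0
    num_complete_slices = total_num_workers // workers_per_slice
    return num_complete_slices * workers_per_slice
```

For example, with 4 workers per slice, 7 available workers floor to 4, so the policy waits for an 8th worker before upscaling to 2 slices.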
I would prefer throwing an error here instead of flooring it. Can we validate `min_workers`/`max_workers` in `_validate_tpu_config`? https://github.com/ray-project/ray/blob/master/python/ray/train/v2/api/config.py#L170
I think we already are validating `min_workers` and `max_workers` here: ray/python/ray/train/v2/api/config.py, line 228 in add50c3.

The flooring handles a different case: for example, if `min_workers` is 4 and resources for 7 TPU nodes are available, we want the policy to upscale 4 workers and then wait until resources for 1 additional worker are available to scale up to 8. Without the flooring logic, even if `min_workers` and `max_workers` are validated correctly, the policy will try to upscale whenever we have available resources for more than `min_workers` workers, even if the available resources are not a full TPU slice.
```python
            and self.scaling_config.topology
            and self.scaling_config.accelerator_type
        ):
            from ray.util.tpu import get_num_ready_tpu_slices, get_tpu_worker_resources
```
It seems that for calculating possible TPU workers we are not using `allocated_resources`. Should we maybe move this into a `_count_possible_tpu_workers` helper instead?
On second thought, I think we still want to use `allocated_resources`. It is possible that in a cluster the "ALIVE" TPU nodes are used by both data loading and training, which is the idea of the `autoscaling_coordinator`: to balance/prioritize resources within the cluster. So it might not be correct for training to check the alive nodes within the Ray cluster to get `num_ready_tpu_slices`.
Yeah, I think we are using `allocated_resources` because we take the count calculated from that value and divide it to get the number of available TPU slices:

```python
num_complete_slices = total_num_workers // workers_per_slice
```

I think you're right that `num_ready_tpu_slices` only returns the alive slices, which doesn't necessarily mean they're available, since they could be used for data loading, etc., but we're already accounting for this by calculating:

```python
num_complete_slices = min(num_complete_slices, num_ready_slices)
```

So we use the `allocated_resources` count to determine how many "free" TPU slices we have, relying on the autoscaler as the source of truth, and then cap that value by the number of "alive" TPU slices that are labeled as part of the same `ray.io/tpu-slice-name` slice. I could make the variable names clearer to describe what's happening.
Suppose 4 slices (slices 1-4) are allocated to task1 and 2 slices (slices 5-6) are allocated to task2, each slice having 4 hosts; then suppose slices 1-4 each have 1 dead node.

`num_complete_slices = total_num_workers // workers_per_slice = 4*3 // 4 = 3`

`num_ready_slices` will be the 2 healthy slices of task2.

In this case, wouldn't `num_complete_slices = min(num_complete_slices, num_ready_slices) = 2`, when actually it should be 0?
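The arithmetic in this counterexample can be checked directly; all numbers below are from the hypothetical scenario in the comment, not real cluster state:

```python
workers_per_slice = 4       # each slice has 4 hosts
total_num_workers = 4 * 3   # slices 1-4 each lost 1 of their 4 hosts -> 12 allocated workers
num_ready_slices = 2        # slices 5-6 are fully healthy, but allocated to task2

num_complete_slices = total_num_workers // workers_per_slice  # 12 // 4 = 3
capped = min(num_complete_slices, num_ready_slices)           # 2, though 0 slices are actually free
```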
Yeah I think you're correct that it should evaluate to 0. I think the current code works in practice because TPU pods with KubeRay will be scaled down atomically, so the surviving orphaned nodes in Slices 1-4 will be rapidly terminated. Even during the brief window before that atomic teardown finishes, GCS enforces the bundle_label_selector set in the SlicePlacementGroups it creates for the 2 slices, so the placement request remains in a PENDING state even when there are fractured slices until the whole slice is healthy and available.
Even though the current solution might work, I agree it's hacky, and it'd be better if the utility returned slices that are both complete and available (rather than just feasible). I didn't do this previously because I was worried about race conditions or inconsistencies in resource availability if the same workers were temporarily used for other tasks like data loading (i.e. we'd still want the PG to be created, but to remain pending until the other task was done and it could be scheduled).
Done in d266c8c, the helper now relies on the State API to get node resource usage.
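A minimal sketch of what such a helper could look like, assuming node info arrives as dicts with a state and a `ray.io/tpu-slice-name` label. The field names and shapes here are illustrative, not the actual State API schema:

```python
from collections import defaultdict


def count_ready_tpu_slices(nodes, hosts_per_slice):
    """Count slices in which every host is alive.

    Each entry of `nodes` is assumed to look like:
        {"state": "ALIVE", "labels": {"ray.io/tpu-slice-name": "slice-1"}}
    """
    alive_hosts_by_slice = defaultdict(int)
    for node in nodes:
        slice_name = node.get("labels", {}).get("ray.io/tpu-slice-name")
        if node.get("state") == "ALIVE" and slice_name is not None:
            alive_hosts_by_slice[slice_name] += 1
    # A slice is ready only when all of its hosts are alive.
    return sum(1 for count in alive_hosts_by_slice.values() if count >= hosts_per_slice)
```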
Thanks! I think it is pretty robust now.
/lgtm
```python
        )

        if failure_decision == FailureDecision.RETRY:
            if self._worker_group:
```
@liulehui this is what ended up fixing the issue we talked about offline. Without this line to shut down the worker group, the elastic policy would stall on the resize decision after killing 1 node. This was happening because, before releasing the placement group and shutting down the worker group, the policy calculates the number of possible workers for the resize decision; but since our TPU logic checks for available slices and the slices were still held by the worker group, it returned resources for 0 workers.
Ah, I see.

In the current state transition flow, I think it is better for RETRY to stay declarative, i.e. it only says "we are going to restart".

Meanwhile, I think it should be fine to put the worker group shutdown in, or before, `_make_and_handle_scaling_decision_for_non_running_worker_group()`, before the scaling policy makes its scaling decision. Since the worker group is non-running already, we should assume the old worker group is gone.

cc @justinvyu @matthewdeng WDYT?
That makes sense to me - updated it to this in 53be039
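The ordering agreed on above can be modeled with a toy controller; all names here are hypothetical stand-ins for the Ray Train internals, not the real classes:

```python
class ToyController:
    """Toy model: release the old worker group before computing a resize."""

    def __init__(self, worker_group, scaling_policy):
        self._worker_group = worker_group
        self._scaling_policy = scaling_policy

    def handle_non_running_worker_group(self):
        # Shut down first, so the TPU slices the old group held count as
        # available when the scaling policy makes its decision.
        if self._worker_group is not None:
            self._worker_group.shutdown()
            self._worker_group = None
        return self._scaling_policy()
```

With the shutdown placed before the scaling decision, the slice-counting logic sees the released resources and no longer stalls at 0 possible workers.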
…61300)

Description: This PR adds a util to check for the number of alive, complete TPU slices in a RayCluster, along with better test coverage. This utility is used in the Ray Train elastic policy to cap the number of workers that can be scaled by the AutoscalingCoordinator.

Related issues: #55162. Related PR: #61299.

Co-authored-by: Mengjin Yan <mengjinyan3@gmail.com>
…rker_group instead
```python
        # Fallback to the raw calculation if the strict math fails
        return total_num_workers
```
Do we ever want this to happen? If `workers_per_slice == 0`, what would the expected behavior be here?
`workers_per_slice == 0` should be impossible, since it'd indicate an invalid topology config (i.e. there is no TPU configuration where the resources required by a worker are greater than the resources on a slice, which is how 0 would be returned here). We validate this in the v2 scaling config and in the TPU util that's called, so this case should never be reached.

The expected behavior would be to just return 0 from this function, since we don't have the resources to scale any workers. Rather than a nested if statement, I'll change it to an early return of 0.
```python
        # The number of workers scaled should be a multiple of the number of
        # workers that fit on a TPU slice.
        return num_available_slices * workers_per_slice
```
If `num_available_slices == 0`, shouldn't we just let this also be 0?
Yeah, I think so. That should already be the case: if we get to this return and `num_available_slices == 0` (because `total_num_workers // workers_per_slice` is 0, or because the `get_num_ready_tpu_slices` call fails), we'll just return 0.

I think you're right, though, that below we shouldn't "fall back to the raw calculation if the strict math fails", and should instead just return 0. Otherwise, if the detection math fails for some reason, we could end up scheduling a partial slice, which will fail since TPU requires SPMD.
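Putting the points from this thread together, a consolidated sketch might look like the following. The function name is hypothetical, and it returns 0 instead of falling back to the raw count, per the conclusion above:

```python
def count_possible_tpu_workers(
    total_num_workers: int, workers_per_slice: int, num_ready_slices: int
) -> int:
    if workers_per_slice <= 0:
        # Invalid topology config; early-return 0 rather than nesting conditionals.
        return 0
    # Cap the feasible slice count by the number of fully alive slices.
    num_available_slices = min(total_num_workers // workers_per_slice, num_ready_slices)
    # Scale in whole slices only; never fall back to a partial-slice count,
    # since TPU SPMD requires complete slices.
    return num_available_slices * workers_per_slice
```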
…er (ray-project#61299)

Description: This PR implements support for elastic training on TPUs using the `JaxTrainer` API and the elastic scaling policy. Specifically, it utilizes a new TPU utility `get_num_ready_tpu_slices` to return the number of full, ready TPU slices in the RayCluster, and adjusts the `_count_possible_workers` calculation when running on TPUs to scale atomically by TPU slices. It also adds comprehensive unit tests and an e2e test for the new support. I'll separate the `ray.util.tpu` change into a separate PR, but left it in for now so that the tests could pass.

Related issues: implements milestone 3 of ray-project#55162.
Description

This PR implements support for elastic training on TPUs using the `JaxTrainer` API and the elastic scaling policy. Specifically, this PR utilizes a new TPU utility `get_num_ready_tpu_slices` to return the number of full, ready TPU slices in the RayCluster and then adjusts the `_count_possible_workers` calculation when running on TPUs to scale atomically by TPU slices. This PR also adds comprehensive unit tests and an e2e test for the new support.

I'll separate the `ray.util.tpu` change into a separate PR, but left it in for now so that the tests could pass.

Related issues

Implements milestone 3 of #55162