[train][v2] implement state management #50515
matthewdeng merged 23 commits into ray-project:master
Conversation
Signed-off-by: Matthew Deng <matt@anyscale.com>
# self._assert_active()
# TODO: Figure out the right validation. This is needed before active.
# Consider passing in the worker group context to the constructor.
assert self._worker_group_context is not None
I think passing it to the constructor and making it non-optional makes sense. We already don't allow people to use the same worker group object for multiple attempts.
Good point, let me try that! Extending this, it might even make sense to get rid of the create method entirely and merge it into __init__, but I can follow up on that in a separate PR.
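The refactor discussed above could look roughly like this (a minimal sketch; the class and attribute names here are simplified stand-ins for the real Ray Train internals):

```python
from dataclasses import dataclass


@dataclass
class WorkerGroupContext:
    """Hypothetical stand-in for the real context object."""

    num_workers: int


class WorkerGroup:
    def __init__(self, worker_group_context: WorkerGroupContext) -> None:
        # The context is a required constructor argument, so the
        # `assert self._worker_group_context is not None` check (and the
        # commented-out `_assert_active()`) become unnecessary.
        self._worker_group_context = worker_group_context

    def get_worker_group_context(self) -> WorkerGroupContext:
        return self._worker_group_context


wg = WorkerGroup(WorkerGroupContext(num_workers=2))
assert wg.get_worker_group_context().num_workers == 2
```

Since a worker group is never reused across attempts, the context can never legitimately be absent, which is what makes the non-optional constructor argument safe.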
def before_worker_group_shutdown(self, worker_group: WorkerGroup):
    worker_group_context: WorkerGroupContext = (
        worker_group.get_worker_group_context()
    )
    worker_group_poll_status: WorkerGroupPollStatus = (
        worker_group.get_latest_poll_status()
    )
    if worker_group_poll_status.errors:
Should we pass the errors as an argument of this callback instead? So we keep the latest poll status detail hidden? I'm fine with either way.
Hmm I guess it's true that this is not always the reason why the worker group errors. There's an implicit relationship between the poll status and why shutdown is actually called. Let me add a TODO here.
For the time being this is somewhat okay:
[Rank 0]
Traceback (most recent call last):
  File "/Users/matt/workspace/scratch/state/fault_tolerance.py", line 9, in f
    raise Exception("test")
Exception: test
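The alternative floated above, passing the errors in as a callback argument so the poll-status detail stays hidden, might look like this (hypothetical signature and class names, not the actual Ray Train API):

```python
from typing import List, Optional


class ShutdownLoggingCallback:
    """Records worker errors passed to the shutdown hook."""

    def __init__(self) -> None:
        self.seen_errors: List[Exception] = []

    def before_worker_group_shutdown(
        self, worker_group: object, errors: Optional[List[Exception]] = None
    ) -> None:
        # The callback no longer reaches into the worker group's latest
        # poll status; any errors arrive as an explicit argument.
        if errors:
            self.seen_errors.extend(errors)


cb = ShutdownLoggingCallback()
cb.before_worker_group_shutdown(None, errors=[Exception("test")])
assert len(cb.seen_errors) == 1
```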
if self.gpu_ids:
    repr_lines.append(f"{indent}gpu_ids={self.gpu_ids},")
GPU ids should be captured already in the accelerator ids dict
Signed-off-by: Matthew Deng <matt@anyscale.com>
@routes.get("/api/train/v2/runs/v1")
@dashboard_optional_utils.init_ray_and_catch_exceptions()
@DeveloperAPI
Can you explain the whole v2/v1 / v1/v2 route situation in the PR description?
I ended up adding this as a comment to the v1 route.
The existing V1 route is currently "/api/train/v2/runs", which is supposed to represent the 2nd iteration of the Train Run state.
However, since we introduced Ray Train V2, it's now ambiguous whether the v2 in the route refers to Train or Train Run. So going forward I want routes to adhere to "/api/train/{train_version}/runs/{run_version}".
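The proposed scheme can be illustrated with a tiny helper (purely illustrative; `train_runs_route` is not a real function in the codebase):

```python
def train_runs_route(train_version: str, run_version: str) -> str:
    """Build a route following "/api/train/{train_version}/runs/{run_version}"."""
    return f"/api/train/{train_version}/runs/{run_version}"


# Ray Train V2 serving the 1st iteration of the run-state schema:
assert train_runs_route("v2", "v1") == "/api/train/v2/runs/v1"
```

Under this scheme, the legacy "/api/train/v2/runs" route reads as "run-state schema v2", which is exactly the ambiguity the new convention removes.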
class TrainStateActor:
    def __init__(self):
        self._runs: Dict[str, TrainRun] = {}
        # {run_id: {attempt_id: TrainRunAttempt}}
        self._run_attempts: Dict[str, Dict[str, TrainRunAttempt]] = defaultdict(dict)
This is probably solved by switching to the export API and removing this actor, but this state actor might start to take up a lot of memory if there are lots of workers and attempts.
Unfortunately this won't be solved entirely by the export API, since the same dictionaries are also stored in the StateManager itself. I will add a NOTE.
def __init__(
    self, train_run_context: TrainRunContext, datasets: Dict[str, GenDataset]
):
    self._state_manager = TrainStateManager()
It seems self._state_manager is initialized on the driver and then passed to the controller process. This _state_manager is a dynamic concept, and we are copying this object from the driver to the controller.
I think, ideally, the init functions of Callbacks that run on the driver should contain static concepts only, i.e., a callback could hold a TrainStateManagerFactory that creates a state_manager on the controller process only.
Currently, we have two TrainStateManager objects across the program: one on the driver, another on the controller. These two objects will have different contents, because the one on the driver will never be used. It is not causing any problem, but the factory method would make it cleaner.
I remember we had a similar discussion for the Metrics Callback. I think we can revamp the Callback structure together in a near-future PR.
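A minimal sketch of the factory idea described above (all names here are hypothetical stand-ins, not the actual Ray Train classes):

```python
class TrainStateManager:
    """Stand-in for the real manager; it accumulates dynamic run state."""

    def __init__(self) -> None:
        self._runs = {}


class TrainStateManagerFactory:
    """Static description of how to build a manager; cheap and safe to copy."""

    def create(self) -> TrainStateManager:
        return TrainStateManager()


class StateManagerCallback:
    def __init__(self, factory: TrainStateManagerFactory) -> None:
        self._factory = factory      # only static config crosses processes
        self._state_manager = None   # created on the controller, not the driver

    def after_controller_start(self) -> None:
        # Lazily build the manager where it is actually used.
        self._state_manager = self._factory.create()


cb = StateManagerCallback(TrainStateManagerFactory())
assert cb._state_manager is None   # nothing created driver-side
cb.after_controller_start()        # simulate running on the controller
assert isinstance(cb._state_manager, TrainStateManager)
```

The point of the pattern is that only the static factory gets serialized to the controller, so there is never an unused, stale manager instance on the driver.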
hongpeng-guo
left a comment
Nice! Left a few comments.
Signed-off-by: Matthew Deng <matt@anyscale.com>
justinvyu
left a comment
Should be good after these comments.
def __init__(
    self,
    worker_group_context: WorkerGroupContext,
):
    self._num_workers = worker_group_context.num_workers
    self._worker_group_state = None
For a follow-up: Let's split this DummyWorkerGroup out of test_controller into a shared test utils file. Also, I think we can add unified methods for creating the mock objects (ex: consolidate the mock constructor methods that are in test_worker_group).
@pytest.fixture(scope="function")
def ray_start_regular():

Suggested change: scope="function" → scope="module"
Decided against this to avoid shared state because of the global TrainStateActor
context = MagicMock(spec=WorkerGroupContext)
context.run_attempt_id = "attempt_1"
context.num_workers = 2
context.resources_per_worker = {"CPU": 1}
Dang MagicMock(spec) is pretty cool. We should use this more (ex: the DummyWorkerGroup can just be a mock with a spec of the actual worker group).
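For illustration, here is why spec'd mocks are handy (the `WorkerGroupContext` below is a stand-in class, not the real one):

```python
from unittest.mock import MagicMock


class WorkerGroupContext:
    """Stand-in with the same attribute surface as the real context."""

    run_attempt_id = ""
    num_workers = 0


context = MagicMock(spec=WorkerGroupContext)
context.num_workers = 2
assert context.num_workers == 2

# Accessing an attribute that does not exist on the spec'd class raises
# AttributeError instead of silently returning a child mock.
try:
    context.num_wrokers  # deliberate typo
    raise AssertionError("expected AttributeError")
except AttributeError:
    pass
```

Using `spec_set` instead of `spec` would additionally reject *setting* attributes outside the spec, which is even stricter.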
actor = MagicMock()
actor._actor_id.hex.return_value = "actor_1"
This mock is a bit nasty, but it's ok since Ray Core doesn't provide a good public API to get the actor id. Let's just move it to a helper function since it's used later on as well.
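The suggested helper might look something like this (name and placement are hypothetical):

```python
from unittest.mock import MagicMock


def make_mock_actor(actor_id: str) -> MagicMock:
    """Build a mock Ray actor handle whose internal `_actor_id.hex()`
    returns the given id, mirroring the inline mock above."""
    actor = MagicMock()
    actor._actor_id.hex.return_value = actor_id
    return actor


actor = make_mock_actor("actor_1")
assert actor._actor_id.hex() == "actor_1"
```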
This PR adds state tracking capabilities to Ray Train V2.

## Key Changes

### State Management

Added new state tracking system for Train V2 that captures:
- Training run status (INITIALIZING, SCHEDULING, RUNNING, etc.)
- Run attempts within each training run and their statuses
- Training worker metadata (ranks, node IP / PID, etc.)

This is done through the following classes:
- [Read][Write] The `TrainStateActor` is the centralized data access object, which is called to write data (currently in memory) and to read the data.
- [Write] The `TrainStateManager` manages the Ray Train state and writes to the `TrainStateActor`.
- [Write] The `StateManagerCallback` implements the `ControllerCallback` and `WorkerGroupCallback` and maps actions from the `Controller` and `WorkerGroup` to the `TrainStateManager`.
- [Read] The `TrainHead` exposes an endpoint for reading from `TrainStateActor`, and performs additional decoration logic before returning it.

### Schema

Defined comprehensive schema for training state including:
- `TrainRun` - Top-level training run information, which maps to one call to `Trainer.fit()`
- `TrainRunAttempt` - Individual training attempts within a training run (e.g. fault tolerance retries).
- `TrainWorker` - Worker-specific state and metrics
- Status enums for runs, attempts, and actors

---

Signed-off-by: Matthew Deng <matt@anyscale.com>
This PR adds Export API support for Ray Train state events.

## Key Changes

- Added new proto messages `ExportTrainRunEventData` and `ExportTrainRunAttemptEventData` to capture training state
- Created `EventLogType` enum to manage different types of export event logs
- Updated `TrainStateActor` to export Train state events when export is enabled
- Modified timestamp fields from milliseconds to nanoseconds (for both proto and Python schema)
  - `start_time_ms` → `start_time_ns`
  - `end_time_ms` → `end_time_ns`

## Implementation Details

- Train run and attempt events are now written to the `event_EXPORT_TRAIN_STATE.log` log file when the export API is enabled
- Export can be enabled either globally or specifically for Train events using environment variables:
  - `RAY_enable_export_api_write=1` (all events)
  - `RAY_enable_export_api_write_config=EXPORT_TRAIN_RUN` (Train run events only)
  - `RAY_enable_export_api_write_config=EXPORT_TRAIN_RUN_ATTEMPT` (Train run attempt events only)

Based off of #47888. Follows the new schema added in #50515.

---

Signed-off-by: Matthew Deng <matt@anyscale.com>
Signed-off-by: Alan Guo <aguo@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Alan Guo <aguo@anyscale.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
Why are these changes needed?

This PR adds state tracking capabilities to Ray Train V2.

Key Changes

State Management

Added new state tracking system for Train V2 that captures:
- Training run status (INITIALIZING, SCHEDULING, RUNNING, etc.)
- Run attempts within each training run and their statuses
- Training worker metadata (ranks, node IP / PID, etc.)

This is done through the following classes:
- [Read][Write] TrainStateActor is the centralized data access object, which is called to write data (currently in memory) and to read the data.
- [Write] TrainStateManager manages the Ray Train state and writes to the TrainStateActor.
- [Write] StateManagerCallback implements the ControllerCallback and WorkerGroupCallback and maps actions from the Controller and WorkerGroup to the TrainStateManager.
- [Read] TrainHead exposes an endpoint for reading from TrainStateActor, and performs additional decoration logic before returning it.

Schema

Defined comprehensive schema for training state including:
- TrainRun - Top-level training run information, which maps to one call to Trainer.fit()
- TrainRunAttempt - Individual training attempts within a training run (e.g. fault tolerance retries).
- TrainWorker - Worker-specific state and metrics

Usage

To enable state tracking in Train V2:

export RAY_TRAIN_ENABLE_STATE_TRACKING=1

Run repro.py:

RAY_TRAIN_ENABLE_STATE_TRACKING=1 RAY_TRAIN_V2_ENABLED=1 python repro.py

Results are available at http://localhost:8265/api/train/v2/runs/v1:
Future Work
Related issue number
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.