
[train][v2] implement state management #50515

Merged
matthewdeng merged 23 commits into ray-project:master from matthewdeng:v2/state
Feb 14, 2025

Conversation

@matthewdeng
Contributor

@matthewdeng matthewdeng commented Feb 13, 2025

Why are these changes needed?

This PR adds state tracking capabilities to Ray Train V2.

Key Changes

State Management

Added a new state tracking system for Train V2 that captures:

  • Training run status (INITIALIZING, SCHEDULING, RUNNING, etc.)
  • Run attempts within each training run and their statuses
  • Training worker metadata (ranks, node IP / PID, etc.)

This is done through the following classes:

  • [Read][Write] The TrainStateActor is the centralized data access object, which is called to write data (currently in memory), and to read the data.
  • [Write] The TrainStateManager manages the Ray Train state, and writes to the TrainStateActor.
  • [Write] The StateManagerCallback implements the ControllerCallback and WorkerGroupCallback and maps actions from the Controller and WorkerGroup to the TrainStateManager.
  • [Read] The TrainHead exposes an endpoint for reading from TrainStateActor, and performs additional decoration logic before returning it.

[architecture diagram]
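The read/write flow of the classes above can be sketched roughly as follows. This is an illustrative sketch only, not the actual Ray Train internals; the method names and record shapes are assumptions that mirror the description.

```python
# Minimal sketch of the state write path (illustrative only).

class TrainStateActor:
    """Centralized data access object: all reads and writes go through here."""

    def __init__(self):
        self._runs = {}  # run_id -> run record (in memory for now)

    def create_or_update_run(self, run):
        self._runs[run["id"]] = run

    def get_runs(self):
        return list(self._runs.values())


class TrainStateManager:
    """Owns Ray Train state transitions and writes them to the actor."""

    def __init__(self, actor):
        self._actor = actor

    def create_run(self, run_id):
        self._actor.create_or_update_run({"id": run_id, "status": "INITIALIZING"})

    def update_run_status(self, run_id, status):
        self._actor.create_or_update_run({"id": run_id, "status": status})


class StateManagerCallback:
    """Maps Controller / WorkerGroup lifecycle hooks to state manager calls."""

    def __init__(self, manager):
        self._manager = manager

    def after_controller_start(self, run_id):
        self._manager.create_run(run_id)

    def after_worker_group_start(self, run_id):
        self._manager.update_run_status(run_id, "RUNNING")


actor = TrainStateActor()
callback = StateManagerCallback(TrainStateManager(actor))
callback.after_controller_start("run_1")
callback.after_worker_group_start("run_1")
print(actor.get_runs())  # [{'id': 'run_1', 'status': 'RUNNING'}]
```

In the real implementation the `TrainStateActor` is a Ray actor and the `TrainHead` dashboard endpoint reads from it; the sketch collapses that into plain objects.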

Schema

Defined a comprehensive schema for training state, including:

  • TrainRun - Top-level training run information, which maps to one call to Trainer.fit()
  • TrainRunAttempt - Individual training attempts within a training run (e.g. fault tolerance retries).
  • TrainWorker - Worker-specific state and metrics
  • Status enums for runs, attempts, and actors
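A hedged sketch of what this schema could look like as Python dataclasses. The enum values come from the description above; the individual field names are assumptions for illustration, not the actual schema.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List

class RunStatus(Enum):
    # Illustrative subset of the run status enum.
    INITIALIZING = "INITIALIZING"
    SCHEDULING = "SCHEDULING"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERRORED = "ERRORED"

@dataclass
class TrainWorker:
    # Worker-specific metadata (field names are assumptions).
    world_rank: int
    node_ip: str
    pid: int

@dataclass
class TrainRunAttempt:
    # One attempt within a run, e.g. a fault-tolerance retry.
    attempt_id: str
    status: str
    workers: List[TrainWorker] = field(default_factory=list)

@dataclass
class TrainRun:
    # Maps to a single Trainer.fit() call.
    id: str
    name: str
    status: RunStatus
    attempts: List[TrainRunAttempt] = field(default_factory=list)
```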

Usage

To enable state tracking in Train V2: export RAY_TRAIN_ENABLE_STATE_TRACKING=1

repro.py:

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import time

def train_func():
    # Stand-in training loop: sleep long enough to inspect the run state.
    time.sleep(100)

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()

RAY_TRAIN_ENABLE_STATE_TRACKING=1 RAY_TRAIN_V2_ENABLED=1 python repro.py

Results are available at http://localhost:8265/api/train/v2/runs/v1:

[Screenshot 2025-02-13 at 6:01:30 PM]

Future Work

  • Implement Export API for persisting data
  • Implement ABORTED status
  • Improve formatting of status details / error messages

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Matthew Deng <matt@anyscale.com>
Contributor

@justinvyu justinvyu left a comment


Nice!

Comment on lines +640 to +643
# self._assert_active()
# TODO: Figure out the right validation. This is needed before active.
# Consider passing in the worker group context to the constructor.
assert self._worker_group_context is not None
Contributor


I think passing it to the constructor and making it non-optional makes sense. We already don't allow people to use the same worker group object for multiple attempts.

Contributor Author


Good point, let me try that! Extending this, it might even make sense to just get rid of the create method and merge it into __init__, but I can follow up on that in a separate PR.

Comment on lines +115 to +122
def before_worker_group_shutdown(self, worker_group: WorkerGroup):
    worker_group_context: WorkerGroupContext = (
        worker_group.get_worker_group_context()
    )
    worker_group_poll_status: WorkerGroupPollStatus = (
        worker_group.get_latest_poll_status()
    )
    if worker_group_poll_status.errors:
Contributor


Should we pass the errors as an argument of this callback instead? So we keep the latest poll status detail hidden? I'm fine with either way.

Contributor Author


Hmm I guess it's true that this is not always the reason why the worker group errors. There's an implicit relationship between the poll status and why shutdown is actually called. Let me add a TODO here.

For the time being this is somewhat okay:

[Rank 0]\nTraceback (most recent call last):\n File "/Users/matt/workspace/scratch/state/fault_tolerance.py", line 9, in f\n raise Exception("test")\nException: test\n

Comment on lines +57 to +59
if self.gpu_ids:
    repr_lines.append(f"{indent}gpu_ids={self.gpu_ids},")
Contributor


GPU ids should be captured already in the accelerator ids dict

Signed-off-by: Matthew Deng <matt@anyscale.com>
Comment on lines +41 to +43
@routes.get("/api/train/v2/runs/v1")
@dashboard_optional_utils.init_ray_and_catch_exceptions()
@DeveloperAPI
Contributor


Can you explain the whole v2/v1 / v1/v2 route situation in the PR description?

Contributor Author


I ended up adding this as a comment to the v1 route.

The existing V1 route is currently "/api/train/v2/runs", which is supposed to represent the 2nd iteration of the Train Run state.

However, since we introduced Ray Train V2, it's now ambiguous whether the v2 in the route refers to Train or the Train Run. So going forward I want it to adhere to "/api/train/{train_version}/runs/{run_version}".

Comment on lines +13 to +17
class TrainStateActor:
    def __init__(self):
        self._runs: Dict[str, TrainRun] = {}
        # {run_id: {attempt_id: TrainRunAttempt}}
        self._run_attempts: Dict[str, Dict[str, TrainRunAttempt]] = defaultdict(dict)
Contributor


This is probably solved by switching to the export API and removing this actor, but this state actor might start to take up a lot of memory if there are lots of workers and attempts.

Contributor Author


Unfortunately this won't be solved entirely by the export API, since the same dictionaries are also stored in the StateManager itself. I will add a NOTE.

Signed-off-by: Matthew Deng <matt@anyscale.com>
def __init__(
    self, train_run_context: TrainRunContext, datasets: Dict[str, GenDataset]
):
    self._state_manager = TrainStateManager()
Contributor


It seems self._state_manager is initialized on the driver and then passed to the controller process. This _state_manager is a dynamic concept, and we are copying this object from the driver to the controller.

I think, ideally, the init function of Callbacks that run on the driver should contain static concepts only. I.e., it could contain a TrainStateManagerFactory that creates a state_manager on the controller process only.

Currently, we have two TrainStateManager objects across the program: one on the driver, another on the controller. These two objects will have different contents, because the one on the driver is never used. It is not causing any problem, but the factory method would make it cleaner.

I remember we had a similar discussion for the Metrics Callback. I think we can revamp the Callback structure together in a near-future PR.
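The factory idea above could look roughly like this. All names here are hypothetical sketches of the suggestion, not the actual Ray Train API.

```python
# Sketch of the suggested factory approach (hypothetical names).

class TrainStateManager:
    def __init__(self):
        self._runs = {}


class TrainStateManagerFactory:
    """Holds only static configuration, so it is cheap and safe to
    serialize from the driver to the controller process."""

    def create(self) -> TrainStateManager:
        return TrainStateManager()


class StateManagerCallback:
    def __init__(self, factory: TrainStateManagerFactory):
        self._factory = factory      # static; copied driver -> controller
        self._state_manager = None   # created lazily, controller-side only

    def after_controller_start(self):
        # The dynamic object is built where it will actually be used.
        self._state_manager = self._factory.create()
```

This way only one TrainStateManager ever exists, and it lives on the controller.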

Contributor

@hongpeng-guo hongpeng-guo left a comment


Nice, nice! Left a few comments.

Signed-off-by: Matthew Deng <matt@anyscale.com>
Contributor

@justinvyu justinvyu left a comment


Should be good after these comments.

Comment on lines +51 to 56
def __init__(
    self,
    worker_group_context: WorkerGroupContext,
):
    self._num_workers = worker_group_context.num_workers
    self._worker_group_state = None
Contributor


For a follow-up: Let's split this DummyWorkerGroup out of test_controller into a shared test utils file. Also, I think we can add unified methods for creating the mock objects (ex: consolidate the mock constructor methods that are in test_worker_group).

Comment on lines +42 to +43
@pytest.fixture(scope="function")
def ray_start_regular():
Contributor


Suggested change:
- @pytest.fixture(scope="function")
+ @pytest.fixture(scope="module")
  def ray_start_regular():

Contributor Author


Decided against this to avoid shared state because of the global TrainStateActor

Comment on lines +57 to +60
context = MagicMock(spec=WorkerGroupContext)
context.run_attempt_id = "attempt_1"
context.num_workers = 2
context.resources_per_worker = {"CPU": 1}
Contributor


Dang MagicMock(spec) is pretty cool. We should use this more (ex: the DummyWorkerGroup can just be a mock with a spec of the actual worker group).
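For readers unfamiliar with it, a short demonstration of what spec buys you (the WorkerGroupContext stand-in below is a simplified assumption, not the real class):

```python
from unittest.mock import MagicMock

class WorkerGroupContext:
    # Simplified stand-in for illustration; the real class lives in Ray Train.
    def placement_strategy(self):
        ...

ctx = MagicMock(spec=WorkerGroupContext)
ctx.num_workers = 2          # setting new attributes is still allowed
assert ctx.num_workers == 2
assert isinstance(ctx, WorkerGroupContext)  # spec also fakes the class

# But *reading* an attribute that is neither set nor on the spec raises
# AttributeError, which catches typos a bare MagicMock would silently absorb:
try:
    ctx.num_wrokers  # typo
except AttributeError:
    print("caught the typo")
```

Use spec_set instead of spec to additionally forbid *setting* attributes that aren't on the spec.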


Comment on lines +66 to +67
actor = MagicMock()
actor._actor_id.hex.return_value = "actor_1"
Contributor


This mock is a bit nasty, but it's ok since Ray Core doesn't provide a good public API to get the actor id. Let's just move it to a helper function since it's used later on as well.
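The suggested helper could be as simple as the following sketch (the helper name is hypothetical):

```python
from unittest.mock import MagicMock

def make_mock_actor(actor_id: str) -> MagicMock:
    """Build a mock Ray actor handle whose _actor_id.hex() returns the
    given id. Helper name is illustrative, not from the codebase."""
    actor = MagicMock()
    actor._actor_id.hex.return_value = actor_id
    return actor

actor = make_mock_actor("actor_1")
assert actor._actor_id.hex() == "actor_1"
```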

Signed-off-by: Matthew Deng <matt@anyscale.com>
Contributor

@justinvyu justinvyu left a comment


Thanks! LGTM :neckbeard:

@matthewdeng matthewdeng marked this pull request as ready for review February 14, 2025 20:12
@matthewdeng matthewdeng enabled auto-merge (squash) February 14, 2025 20:13
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Feb 14, 2025
Signed-off-by: Matthew Deng <matt@anyscale.com>
@github-actions github-actions bot disabled auto-merge February 14, 2025 20:50
@matthewdeng matthewdeng enabled auto-merge (squash) February 14, 2025 20:53
@matthewdeng matthewdeng merged commit 4484e73 into ray-project:master Feb 14, 2025
3 checks passed
xsuler pushed a commit to antgroup/ant-ray that referenced this pull request Mar 4, 2025
xsuler pushed a commit to antgroup/ant-ray that referenced this pull request Mar 4, 2025
matthewdeng added a commit that referenced this pull request Mar 4, 2025
This PR adds Export API support for Ray Train state events.

## Key Changes

- Added new proto messages `ExportTrainRunEventData` and
`ExportTrainRunAttemptEventData` to capture training state
- Created `EventLogType` enum to manage different types of export event
logs
- Updated `TrainStateActor` to export Train state events when export is
enabled
- Modified timestamp fields from milliseconds to nanoseconds (for both
proto and python schema)
  - `start_time_ms` → `start_time_ns`
  - `end_time_ms` → `end_time_ns`

## Implementation Details

- Train run and attempt events are now written to the
`event_EXPORT_TRAIN_STATE.log` log file when the export API is enabled
- Export can be enabled either globally or specifically for Train events
using environment variables:
  - `RAY_enable_export_api_write=1` (all events)
- `RAY_enable_export_api_write_config=EXPORT_TRAIN_RUN` (Train run
events only)
- `RAY_enable_export_api_write_config=EXPORT_TRAIN_RUN_ATTEMPT` (Train
run attempt events only)

Based off of #47888.
Follows the new schema added in
#50515.

---------

Signed-off-by: Matthew Deng <matt@anyscale.com>
Signed-off-by: Alan Guo <aguo@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Alan Guo <aguo@anyscale.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
abrarsheikh pushed a commit that referenced this pull request Mar 8, 2025
park12sj pushed a commit to park12sj/ray that referenced this pull request Mar 18, 2025
park12sj pushed a commit to park12sj/ray that referenced this pull request Mar 18, 2025