[train][checkpoint] Add checkpoint_upload_mode to ray.train.report#55637

Merged
justinvyu merged 19 commits into ray-project:master from TimothySeah:tseah/async-checkpoint
Sep 11, 2025
Conversation


@TimothySeah TimothySeah commented Aug 15, 2025

Summary

Implement async checkpoint uploads in ray.train.report(..., checkpoint_upload_mode), supporting SYNC (default), ASYNC, and NO_UPLOAD.

  • Introduce per-worker checkpoint counters to preserve report order.
  • Use a thread pool to limit concurrent uploads and avoid OOM.
  • Wrap the training function to wait for pending uploads before exiting.
  • Add delete_local_checkpoint_after_upload to control temporary local directory cleanup.

Implementation Summary

This PR implements async checkpointing by

  • Adding a checkpoint_upload_mode to ray.train.report with three options
  • Maintaining internal-only num_reported_checkpoints and num_attempted_reported_checkpoints counters on the TrainContext
  • Regardless of which combination of checkpoint_upload_mode values the different Ray Train workers use, checkpoints should reach the result queue in the order they were ray.train.report-ed. Therefore, each report waits for its turn (num_reported_checkpoints == current_report_attempt_number - 1) before adding its checkpoint to the result queue.
  • Uploading too many checkpoints concurrently risks OOM-ing the worker, so I added a ThreadPoolExecutor to bound the number of concurrent checkpoint upload threads.
  • I changed run_train_fn to wrap the train_fn in train_fn_that_waits_for_threads. Otherwise we could hit the following situation: 1) the train function exits with report threads still pending and the worker status becomes finished; 2) the controller sees the finished status and shuts down the worker group; 3) result.fit does not return all of the reported checkpoints/metrics.
    • I implemented "early exit" in ThreadRunner but "wait for threads" as a wrapper function because in the former case that is the cleanest way for a nested thread to make the entire worker exit early, while in the latter case the target function can wait for the threads it creates without complicating the ThreadRunner abstraction.
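The ordering and bounded-concurrency scheme described above can be sketched roughly as follows. This is an illustrative standalone model, not the PR's actual code: the class name OrderedReporter, the condition-variable wait, and the in-process result_queue list are all assumptions standing in for Ray Train internals.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

class OrderedReporter:
    """Sketch: enqueue checkpoints in report order despite async uploads."""

    def __init__(self, max_concurrent_uploads=2):
        self.num_reported_checkpoints = 0
        self.num_attempted_reported_checkpoints = 0
        self._cond = threading.Condition()
        # A bounded pool guards against OOM from too many concurrent uploads.
        self._pool = ThreadPoolExecutor(max_workers=max_concurrent_uploads)
        self.result_queue = []

    def report(self, checkpoint):
        with self._cond:
            self.num_attempted_reported_checkpoints += 1
            attempt = self.num_attempted_reported_checkpoints
        return self._pool.submit(self._upload_and_enqueue, checkpoint, attempt)

    def _upload_and_enqueue(self, checkpoint, attempt):
        # (checkpoint upload to storage would happen here)
        with self._cond:
            # Wait for our turn: all earlier reports must have finished.
            self._cond.wait_for(
                lambda: self.num_reported_checkpoints == attempt - 1
            )
            self.result_queue.append(checkpoint)
            self.num_reported_checkpoints += 1
            self._cond.notify_all()

    def shutdown(self):
        # Wait for pending uploads before the train fn is allowed to exit.
        self._pool.shutdown(wait=True)

reporter = OrderedReporter()
for i in range(5):
    reporter.report(f"ckpt-{i}")
reporter.shutdown()
print(reporter.result_queue)  # checkpoints appear in report order
```

Because ThreadPoolExecutor schedules submitted tasks FIFO, an upload for attempt k can only start after attempts 1..k-1 have started, so the wait_for never deadlocks.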

A few other notes:

  • I decided to only add Checkpoints (instead of Checkpoint ObjectRefs) to the result queue because:
    • If we went with the ObjectRef approach, the controller would create a Ray task that updates controller state. This "driver creates task that updates driver" pattern is unwieldy to implement.
    • Every worker must upload its checkpoint so it makes sense to confine this logic to the worker rather than making the controller even more complicated than it already is.
  • One interesting corner case I found while unit testing this PR is that async checkpoint uploads don't work with temporary directories because we might exit the temporary directory's scope before we kick off the checkpoint upload.
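The temporary-directory corner case above can be demonstrated without any Ray code. This is a minimal illustration (fake_async_upload is a hypothetical stand-in for the upload thread): the `with` block exits and deletes the directory before the asynchronous "upload" ever reads the path.

```python
import os
import tempfile
import threading
import time

results = {}

def fake_async_upload(path):
    # Simulate an upload that starts after the tempdir scope has exited.
    time.sleep(0.1)
    results["exists"] = os.path.exists(path)

with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "model.pt")
    open(ckpt, "w").close()
    t = threading.Thread(target=fake_async_upload, args=(ckpt,))
    t.start()
# The directory is deleted here, while the "upload" is still pending.
t.join()
print(results["exists"])  # False: the checkpoint path is already gone
```

This is why ASYNC mode defaults delete_local_checkpoint_after_upload to True and expects a non-temporary local directory.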

API Changes

This PR's only API changes are adding the following two arguments to ray.train.report:

  • checkpoint_upload_mode:
    • SYNC: synchronous upload - current and default behavior
    • ASYNC: asynchronous upload - the main goal of this PR
    • NO_UPLOAD: do not upload checkpoint - useful when users upload checkpoints themselves
  • delete_local_checkpoint_after_upload: Whether to delete the checkpoint after uploading it. Users generally won't need to set this since each checkpoint upload mode has its own default:
    • SYNC: False because users will generally use tempfile
    • ASYNC: True because users can't use tempfile - see previous section for explanation
    • NO_UPLOAD: False because there is no local directory to delete

Here's a simple example of this API in action:

def train_func():
    ...
    ray.train.report(
        metrics={},
        checkpoint=Checkpoint.from_directory(checkpoint_dir),
        checkpoint_upload_mode=CheckpointUploadMode.ASYNC,
    )

Testing

Looks like async reporting is indeed faster, with the same loss, on the PyTorch Ray Train example: https://docs.ray.io/en/latest/train/getting-started-pytorch.html

Sync mode

(RayTrainWorker pid=6316, ip=10.0.180.165) Blocked times: [0.2751896381378174, 0.28496718406677246, 0.26192378997802734, 0.25046420097351074, 0.2681725025177002, 0.2644937038421631, 0.27478623390197754, 0.2887108325958252, 0.36760926246643066, 0.32657504081726074] with total 2.8628923892974854

{'loss': 0.04541657865047455, 'epoch': 9}

3m3s

Async mode

(RayTrainWorker pid=8960, ip=10.0.185.254) Blocked times: [0.005418062210083008, 0.004388093948364258, 0.0045735836029052734, 0.004605531692504883, 0.004551887512207031, 0.0025453567504882812, 0.021490812301635742, 0.014325380325317383, 0.0050373077392578125, 0.004266977310180664] with total 0.07120299339294434

{'loss': 0.04778982326388359, 'epoch': 9}

2m57s with only ~0.22s blocking time when waiting for the last checkpoint upload:

2025-08-29 20:44:25.066 | Blocked times: [0.005418062210083008, 0.004388093948364258, 0.0045735836029052734, 0.004605531692504883, 0.004551887512207031, 0.0025453567504882812, 0.021490812301635742, 0.014325380325317383, 0.0050373077392578125, 0.004266977310180664] with total 0.07120299339294434
2025-08-29 20:44:25.289 | Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-29_20-44-24.990364)
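A blocked-time measurement like the one logged above can be gathered by timing each report call. This is a hypothetical sketch (timed_report and fake_report are stand-ins; a real run would time the actual ray.train.report): with ASYNC uploads the call returns almost immediately, so per-epoch blocked time stays near zero.

```python
import time

blocked_times = []

def timed_report(report_fn, *args, **kwargs):
    # Measure how long the training loop is blocked inside the report call.
    start = time.monotonic()
    report_fn(*args, **kwargs)
    blocked_times.append(time.monotonic() - start)

def fake_report(metrics, checkpoint=None):
    # Stand-in for ray.train.report with an async upload mode.
    pass

for epoch in range(3):
    timed_report(fake_report, {"epoch": epoch})

print(f"Blocked times: {blocked_times} with total {sum(blocked_times)}")
```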
