[RayTrain] Manual checkpoint persistence to storage #52762

@mgharbi

Description

It would be neat to be able to bypass Ray Train's automatic persistence of checkpoints to the storage backend (e.g. S3 or GCS). There is currently no way to opt out of this behavior.

In my use case, I am writing custom code to persist checkpoints to GCS using PyTorch's distributed checkpointing library (which does not require staging to disk before an asynchronous upload). Ray's current report API forces a redundant copy of the checkpoint, even though it already lives in the cloud.
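For context, this is roughly what the custom persistence path looks like. A minimal sketch, assuming a recent PyTorch (>= 2.3) where torch.distributed.checkpoint.async_save is available, gcsfs installed so fsspec can resolve gs:// URIs, and a made-up bucket path; note the FsspecWriter import comes from a private module whose location may vary across PyTorch versions:

import torch.distributed.checkpoint as dcp
# Private module as of recent PyTorch releases; import path may vary by version.
from torch.distributed.checkpoint._fsspec_filesystem import FsspecWriter

# `model` comes from the surrounding training loop; torch.distributed is assumed
# to be initialized (TorchTrainer workers already run inside a process group).
state_dict = {"model": model.state_dict()}
upload = dcp.async_save(
    state_dict,
    storage_writer=FsspecWriter("gs://my-bucket/ckpts/step-0001"),  # hypothetical path
)
# Training can continue while the upload runs in the background;
# wait for it to finish before reporting the checkpoint to Ray.
upload.result()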

Ideally, I would like to simply "inform" the Ray TorchTrainer of that custom checkpoint, so it can be used in the Ray Result, e.g. for auto-resume functionality. A hypothetical API shape is sketched after the code below.

# Current API
import ray.train

# `metrics` and `persisted_ckpt` come from the user's training code.
ray.train.report(metrics, checkpoint=persisted_ckpt)
# This creates (copies) a new checkpoint on the backend storage.

# Workaround to achieve the desired effect, bypassing Ray Train's
# automatic checkpoint upload via internal APIs:
from ray.train._internal.session import _TrainingResult
from ray.train.v2._internal.execution.context import (
    get_train_context as get_internal_train_context,
)

result = _TrainingResult(checkpoint=persisted_ckpt, metrics=metrics)
get_internal_train_context().get_result_queue().put(result)
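For illustration only, here is a hypothetical shape the requested API could take; the `checkpoint_already_persisted` keyword is invented for this sketch and is not part of Ray:

import ray.train

# Hypothetical proposed API -- the keyword argument is invented for illustration.
# `persisted_ckpt` would reference the checkpoint already uploaded to GCS.
ray.train.report(
    metrics,
    checkpoint=persisted_ckpt,
    checkpoint_already_persisted=True,  # ask Ray Train to skip its own upload
)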
