[RayTrain] Manual checkpoint persistence to storage #52762
Labels: P1 (issue that should be fixed within a few weeks), community-backlog, feature-request, train (Ray Train related issue), usability
Description
It would be useful to be able to bypass Ray Train's automatic persistence of checkpoints to the storage backend (e.g. S3 or GCS); there is currently no way to opt out of it.
In my use case, I am writing custom code to persist checkpoints to GCS using PyTorch's distributed checkpointing library (which does not require staging to disk before an asynchronous upload). Ray's current report API forces a copy of the uploaded checkpoint, even though it is already stored in the cloud.
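For reference, here is a minimal sketch of that persistence pattern, assuming a recent PyTorch release that exports `torch.distributed.checkpoint.async_save`; the model and the `gs://` bucket path are hypothetical placeholders:

```python
import torch
import torch.distributed.checkpoint as dcp

# Hypothetical model; anything exposing state_dict() works here.
model = torch.nn.Linear(8, 8)
state_dict = {"model": model.state_dict()}

# async_save stages tensors in host memory and writes in a background
# thread, so the checkpoint is never staged on local disk. Resolving a
# gs:// URI directly requires fsspec/gcsfs support in your environment,
# and in a multi-process run this must be called from all ranks.
future = dcp.async_save(state_dict, checkpoint_id="gs://my-bucket/ckpt-0")

# ... continue training while the upload proceeds ...
future.result()  # block until the checkpoint is fully persisted
```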
Ideally, I would like to simply "inform" the Ray TorchTrainer of that custom checkpoint so that it can be surfaced in the Ray Result, e.g. for auto-resume functionality.
```python
import ray.train

# `metrics` and `persisted_ckpt` are assumed to be defined by the caller.

# Current API
ray.train.report(metrics, checkpoint=persisted_ckpt)
# This will create a new copy of the checkpoint on the backend storage.

# Code to achieve the desired effect: bypass Ray Train's automatic checkpoint upload
from ray.train._internal.session import _TrainingResult
from ray.train.v2._internal.execution.context import (
    get_train_context as get_internal_train_context,
)

result = _TrainingResult(checkpoint=persisted_ckpt, metrics=metrics)
get_internal_train_context().get_result_queue().put(result)
```
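For context, a hedged sketch of how this workaround might sit inside a Ray Train v2 training function. The `_TrainingResult` and result-queue APIs are private and may change between Ray versions; the GCS URI and the config values are placeholders:

```python
from ray.train import Checkpoint, ScalingConfig
from ray.train._internal.session import _TrainingResult
from ray.train.torch import TorchTrainer
from ray.train.v2._internal.execution.context import (
    get_train_context as get_internal_train_context,
)

def train_func(config):
    for epoch in range(config["epochs"]):
        # ... train, then persist the checkpoint yourself, e.g. with
        # torch.distributed.checkpoint writing straight to GCS ...
        persisted_ckpt = Checkpoint("gs://my-bucket/run-1/ckpt")  # placeholder URI
        metrics = {"epoch": epoch}
        # Hand the already-persisted checkpoint to Ray Train without
        # triggering a second upload (private API, use at your own risk).
        result = _TrainingResult(checkpoint=persisted_ckpt, metrics=metrics)
        get_internal_train_context().get_result_queue().put(result)

trainer = TorchTrainer(
    train_func,
    train_loop_config={"epochs": 1},
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()
```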
Use case

No response