[train] Add checkpoint util functions for JaxTrainer. #60759
siyuanfoundation wants to merge 1 commit into ray-project:master
Conversation
Code Review
This pull request introduces a JaxCheckpointManager to handle checkpointing for JaxTrainer by wrapping orbax.CheckpointManager. This is a significant and well-implemented feature for improving JAX support in Ray Train. The new functionality is supported by comprehensive tests for both single-host and multi-host scenarios. My review includes suggestions for improving code clarity, maintainability, and test robustness.
Force-pushed from 8ecaeca to cfe32c8
/cc @ryanaoleary
Force-pushed from cfe32c8 to 2b88598
Force-pushed from 7ff843f to 77ac30a
liulehui left a comment
Could you elaborate on the requirements/functionality in the PR description?
ty!
```python
from ray.train._internal.session import _TrainingResult
from ray.train._internal.storage import StorageContext, _exists_at_fs_path
```
I think these are for Ray Train v1; this is our v2 checkpoint manager:
https://github.com/liulehui/ray/blob/oss-elastic-training/python/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py#L77
I recently added a GPT-2 template, https://docs.ray.io/en/master/train/examples/jax/intro_to_jax_trainer/README.html, which uses Orbax for checkpointing as well.
Is there any more functionality/requirement needed beyond this one?
```python
)

# Use PyTreeCheckpointHandler for standard PyTree saving
item_handlers = {
```
Is it possible/supported for users to pass different arguments here? It might be good to expose orbax_options or something similar to define the arguments users can pass to the CheckpointManager when instantiating their JaxCheckpointManager.
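One possible shape for this, sketched without importing Orbax: a dict of user-supplied options merged over the manager's defaults before being forwarded to `ocp.CheckpointManagerOptions`. The names `orbax_options`, `DEFAULT_ORBAX_OPTIONS`, and `build_manager_options` are assumptions for illustration, not part of this PR:

```python
# Hypothetical sketch: merge user-supplied Orbax options over Ray-side defaults.
# In a real JaxCheckpointManager the merged dict would be forwarded as
# ocp.CheckpointManagerOptions(**merged); Orbax itself is not imported here.
DEFAULT_ORBAX_OPTIONS = {"max_to_keep": 2, "create": True}


def build_manager_options(orbax_options=None):
    merged = dict(DEFAULT_ORBAX_OPTIONS)
    # User-supplied keys override defaults; unknown keys pass through to Orbax.
    merged.update(orbax_options or {})
    return merged


opts = build_manager_options({"max_to_keep": 5, "save_interval_steps": 100})
```

This keeps the Ray-side defaults stable while letting users reach any `CheckpointManagerOptions` field without Ray having to mirror each one.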
```python
save_args = ocp_args.PyTreeSave(
    item=train_state,
    save_args=jax.tree.map(
        lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), train_state
    ),
)
```
Discussed offline: let's try to have a layout similar to what the other frameworks currently have, e.g.:
```python
def train_fn_per_worker(train_loop_config: dict):
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as temp_checkpoint_dir:
            # pass in new workers and mesh/sharding config to restore
            restore_args_structure = jax.tree.map(map_to_restore_args, target)
            checkpoint_args = ocp_args.PyTreeRestore(
                item=target, restore_args=restore_args_structure
            )
            model = orbax.restore(args=ocp_args.Composite(items=checkpoint_args))
    # continue training
    ...
    # save with current mesh sharding
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        success = orbax.checkpointer.save(
            step, args=ocp_args.Composite(items=train_state), sharding=sharding
        )
        ray.train.report(
            {"loss": 0.1},
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )
```
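The save/restore layout sketched above can be exercised end-to-end with a pure-stdlib stand-in. Here `save_and_report` and `restore_state` play the roles of `ray.train.report` plus `Checkpoint.from_directory` and `ray.train.get_checkpoint` plus `as_directory`, and `json` stands in for Orbax serialization; all names are illustrative, not Ray APIs:

```python
import json
import os
import shutil
import tempfile


def save_and_report(state, persist_dir):
    # Write state into a scratch dir, then "report" it by persisting a copy,
    # mirroring Checkpoint.from_directory + ray.train.report in the sketch above.
    with tempfile.TemporaryDirectory() as tmp:
        with open(os.path.join(tmp, "state.json"), "w") as f:
            json.dump(state, f)
        shutil.copytree(tmp, persist_dir, dirs_exist_ok=True)


def restore_state(persist_dir):
    # Counterpart of ray.train.get_checkpoint() + checkpoint.as_directory().
    with open(os.path.join(persist_dir, "state.json")) as f:
        return json.load(f)
```

The key property the layout relies on is that everything written under the temporary directory before `report` is what `get_checkpoint` hands back on restore.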
Force-pushed from 2297a66 to c07acb2
Force-pushed from c07acb2 to 96a15b0
Signed-off-by: siyuanfoundation <sizhang@google.com>
Force-pushed from 96a15b0 to adc697f
```python
        mesh=x.sharding.mesh, sharding=x.sharding
    )
    if isinstance(x, (jax.Array, jax.ShapeDtypeStruct))
    and hasattr(x, "sharding")
```
Null sharding access crashes on ShapeDtypeStruct targets
Medium Severity
The guard hasattr(x, "sharding") is insufficient for jax.ShapeDtypeStruct because its sharding attribute always exists but defaults to None. When a ShapeDtypeStruct with sharding=None is passed as part of the target, hasattr returns True, then x.sharding.mesh raises an AttributeError since None has no mesh. The check needs to also verify that x.sharding is not None (e.g., using getattr(x, "sharding", None) is not None).
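The failure mode is reproducible without JAX. The stand-in classes below mimic `jax.ShapeDtypeStruct`'s always-present but possibly-None `sharding` attribute; `FakeSharding`, `FakeShapeDtypeStruct`, and both `mesh_of_*` helpers are hypothetical names for illustration:

```python
# Hypothetical stand-ins reproducing the guard bug without JAX: like
# jax.ShapeDtypeStruct, FakeShapeDtypeStruct always has a `sharding`
# attribute, but it may be None.
class FakeSharding:
    def __init__(self, mesh):
        self.mesh = mesh


class FakeShapeDtypeStruct:
    def __init__(self, sharding=None):
        self.sharding = sharding


def mesh_of_buggy(x):
    # hasattr() is True even when sharding is None, so this dereferences None
    # and raises AttributeError for an unsharded ShapeDtypeStruct-like target.
    if hasattr(x, "sharding"):
        return x.sharding.mesh
    return None


def mesh_of_fixed(x):
    # The suggested fix: also require a non-None sharding before touching .mesh.
    if getattr(x, "sharding", None) is not None:
        return x.sharding.mesh
    return None


sharded = FakeShapeDtypeStruct(FakeSharding(mesh="tpu_mesh"))
unsharded = FakeShapeDtypeStruct()  # sharding defaults to None
```

`mesh_of_buggy(unsharded)` raises `AttributeError`, while `mesh_of_fixed` handles both cases.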
| """ | ||
| import orbax.checkpoint as ocp | ||
|
|
||
| checkpointer = ocp.PyTreeCheckpointer() |
I think this can live in the user's training function script instead of us forcing it. They can choose to do

```python
checkpointer.save(checkpoint_dir, item, force=force)
mlflow.log_model(checkpoint_dir)
```

without using ray.train.report. We can leave this to the user.
```python
restore_args = jax.tree_util.tree_map(
    lambda x: type_handlers.ArrayRestoreArgs(
        mesh=x.sharding.mesh, sharding=x.sharding
    )
    if isinstance(x, (jax.Array, jax.ShapeDtypeStruct))
    and hasattr(x, "sharding")
    else ocp.checkpoint_utils.construct_restore_args(x),
    target,
    is_leaf=lambda x: isinstance(x, (jax.Array, jax.ShapeDtypeStruct)),
)
```
Just for my understanding: is the main difference here that we pass in target (which would include both the model definition and the sharding info) so that we can restore from a previous checkpoint?
Would it be sufficient to keep the Mesh/Sharding in the training context so that the user can just use that for restoring? In that case, I think we only need a util to construct the restore_args.
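A sketch of such a util, using a nested dict in place of a JAX pytree so it runs without JAX installed: `RestoreArg` stands in for Orbax's `ArrayRestoreArgs`, and `build_restore_args` is a hypothetical name; a real version would run `jax.tree_util.tree_map` over the target with the mesh/sharding taken from the training context:

```python
from dataclasses import dataclass


@dataclass
class RestoreArg:
    # Stand-in for orbax ArrayRestoreArgs(mesh=..., sharding=...).
    mesh: str
    sharding: str


def build_restore_args(target, mesh, sharding):
    # Recursively mirror the target structure, attaching the context's
    # mesh/sharding to every leaf; the real util would use jax.tree_util.tree_map
    # with an is_leaf predicate instead of this manual dict recursion.
    if isinstance(target, dict):
        return {k: build_restore_args(v, mesh, sharding) for k, v in target.items()}
    return RestoreArg(mesh=mesh, sharding=sharding)
```

With the mesh/sharding held in the training context, the user only supplies the target structure and gets back a matching tree of restore args.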
Description
Add checkpoint util functions for JaxTrainer.
This is optional to use (could serve as an example). Users can always use their own checkpoint function.
Related issues
#55162
Additional information