
[train] Add checkpoint util functions for JaxTrainer. #60759

Closed
siyuanfoundation wants to merge 1 commit into ray-project:master from siyuanfoundation:jax-checkpoint

Conversation

@siyuanfoundation
Contributor

@siyuanfoundation siyuanfoundation commented Feb 4, 2026

Description

Add checkpoint util functions for JaxTrainer.

These utilities are optional to use (they could also serve as an example); users can always use their own checkpoint function.

Related issues

#55162

Additional information

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a JaxCheckpointManager to handle checkpointing for JaxTrainer by wrapping orbax.CheckpointManager. This is a significant and well-implemented feature for improving JAX support in Ray Train. The new functionality is supported by comprehensive tests for both single-host and multi-host scenarios. My review includes suggestions for improving code clarity, maintainability, and test robustness.

@siyuanfoundation siyuanfoundation force-pushed the jax-checkpoint branch 3 times, most recently from 8ecaeca to cfe32c8 Compare February 6, 2026 14:51
@siyuanfoundation siyuanfoundation marked this pull request as ready for review February 6, 2026 15:30
@siyuanfoundation
Contributor Author

/cc @ryanaoleary

@matthewdeng matthewdeng requested a review from liulehui February 6, 2026 18:03
@ray-gardener ray-gardener bot added the community-contribution Contributed by the community label Feb 6, 2026
@liulehui liulehui self-assigned this Feb 9, 2026
Contributor

@liulehui liulehui left a comment


Thank you!!

Contributor

@liulehui liulehui left a comment


Could you elaborate on the requirements/functionality in the PR description?

ty!

Comment on lines +6 to +7
from ray.train._internal.session import _TrainingResult
from ray.train._internal.storage import StorageContext, _exists_at_fs_path
Contributor


I think these are for Ray Train v1; this is our v2 checkpoint manager:
https://github.com/liulehui/ray/blob/oss-elastic-training/python/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py#L77

I recently added a GPT-2 template (https://docs.ray.io/en/master/train/examples/jax/intro_to_jax_trainer/README.html) which uses orbax for checkpointing as well.

Is there any more functionality/requirement needed beyond this one?

)

# Use PyTreeCheckpointHandler for standard PyTree saving
item_handlers = {
Contributor


Is it possible/supported for users to pass different arguments here? It might be good to expose orbax_options or something similar to define the arguments users can pass to the CheckpointManager when instantiating their JaxCheckpointManager.
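The `orbax_options` pass-through suggested above could look roughly like this. This is a minimal sketch: `JaxCheckpointManager` here is a toy stand-in (as is `_FakeOrbaxManager`), not the PR's actual class; the point is only that user-supplied options are forwarded unchanged to the wrapped manager instead of being hard-coded.

```python
class _FakeOrbaxManager:
    """Toy stand-in for orbax.checkpoint.CheckpointManager, for illustration only."""

    def __init__(self, directory, **options):
        self.directory = directory
        self.options = options


class JaxCheckpointManager:
    """Hypothetical wrapper that exposes an `orbax_options` dict to users."""

    def __init__(self, directory, orbax_options=None):
        # Forward user-provided options verbatim instead of hard-coding
        # item_handlers and friends inside the wrapper.
        self._manager = _FakeOrbaxManager(directory, **(orbax_options or {}))


mgr = JaxCheckpointManager("/tmp/ckpt", orbax_options={"max_to_keep": 3})
print(mgr._manager.options["max_to_keep"])  # → 3
```

With a real orbax backend, `orbax_options` would map onto the arguments of `ocp.CheckpointManager` (or `ocp.CheckpointManagerOptions`), which is what the comment above is asking to expose.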

Comment on lines +112 to +117
save_args = ocp_args.PyTreeSave(
item=train_state,
save_args=jax.tree.map(
lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), train_state
),
)
Contributor


Discussed offline: let's try to have a layout similar to what the other frameworks currently have.

e.g.

def train_fn_per_worker(train_loop_config: dict):

    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as temp_checkpoint_dir:

            # Pass in the new workers' mesh/sharding config to restore.
            restore_args_structure = jax.tree.map(map_to_restore_args, target)
            checkpoint_args = ocp_args.PyTreeRestore(
                item=target, restore_args=restore_args_structure
            )
            model = checkpointer.restore(
                temp_checkpoint_dir, args=ocp_args.Composite(items=checkpoint_args)
            )

        # Continue training.
        ...

    # Save with the current mesh sharding.
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:

        checkpointer.save(
            temp_checkpoint_dir,
            args=ocp_args.Composite(items=ocp_args.PyTreeSave(item=train_state)),
        )
        ray.train.report(
            {"loss": 0.1},
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )

@siyuanfoundation siyuanfoundation force-pushed the jax-checkpoint branch 2 times, most recently from 2297a66 to c07acb2 Compare February 17, 2026 21:11
@siyuanfoundation siyuanfoundation changed the title [train] Add checkpoint manager for JaxTrainer. [train] Add checkpoint util functions for JaxTrainer. Feb 17, 2026
Signed-off-by: siyuanfoundation <sizhang@google.com>

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

mesh=x.sharding.mesh, sharding=x.sharding
)
if isinstance(x, (jax.Array, jax.ShapeDtypeStruct))
and hasattr(x, "sharding")


Null sharding access crashes on ShapeDtypeStruct targets

Medium Severity

The guard hasattr(x, "sharding") is insufficient for jax.ShapeDtypeStruct because its sharding attribute always exists but defaults to None. When a ShapeDtypeStruct with sharding=None is passed as part of the target, hasattr returns True, then x.sharding.mesh raises an AttributeError since None has no mesh. The check needs to also verify that x.sharding is not None (e.g., using getattr(x, "sharding", None) is not None).
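The failure mode and the suggested guard are pure Python, so they can be illustrated without JAX. The classes below are toy stand-ins for `jax.Array` and `jax.ShapeDtypeStruct`: they show why `hasattr` passes even when the attribute is `None`, and how the `getattr(...) is not None` check avoids the crash.

```python
class FakeSharding:
    """Stand-in for a JAX sharding object that carries a mesh."""
    mesh = "mesh0"


class FakeArray:
    """Stand-in for jax.Array: sharding is present and non-None."""
    sharding = FakeSharding()


class FakeShapeDtypeStruct:
    """Stand-in for jax.ShapeDtypeStruct: the attribute exists but is None."""
    sharding = None


def mesh_of_buggy(x):
    # Buggy guard: hasattr() is True for FakeShapeDtypeStruct too,
    # so x.sharding.mesh raises AttributeError (None has no .mesh).
    if hasattr(x, "sharding"):
        return x.sharding.mesh
    return None


def mesh_of_fixed(x):
    # Suggested guard: also require the value itself to be non-None.
    if getattr(x, "sharding", None) is not None:
        return x.sharding.mesh
    return None
```

In the PR's `tree_map` lambda, the same change means replacing `hasattr(x, "sharding")` with `getattr(x, "sharding", None) is not None` so that leaves with `sharding=None` fall through to `ocp.checkpoint_utils.construct_restore_args(x)` instead of crashing.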


"""
import orbax.checkpoint as ocp

checkpointer = ocp.PyTreeCheckpointer()
Contributor


I think this can live in the user's training function script instead of us forcing it.

They can choose to do

checkpointer.save(checkpoint_dir, item, force=force)
mlflow.log_model(checkpoint_dir)

without using ray.train.report.

We can leave this to the user.

Comment on lines +57 to +66
restore_args = jax.tree_util.tree_map(
lambda x: type_handlers.ArrayRestoreArgs(
mesh=x.sharding.mesh, sharding=x.sharding
)
if isinstance(x, (jax.Array, jax.ShapeDtypeStruct))
and hasattr(x, "sharding")
else ocp.checkpoint_utils.construct_restore_args(x),
target,
is_leaf=lambda x: isinstance(x, (jax.Array, jax.ShapeDtypeStruct)),
)
Contributor


Just for my understanding: is the main difference here that we pass in target (which includes both the model definition and the sharding info) so that we can restore from a previous checkpoint?

Would it be sufficient to keep the Mesh/Sharding in the training context so that users can just use that for restoring? In that case, I think we only need a util to construct the restore_args, right?
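The "small util" idea above might look roughly like this. Everything here is hypothetical: the context dict, the helper name, and the plain-dict restore-arg records are stand-ins so the shape of the idea is clear without JAX; a real version would use `jax.tree_util.tree_map` over the target PyTree and build `ocp.ArrayRestoreArgs` from the sharding kept in the training context.

```python
def build_restore_args(target_leaves, context):
    """Hypothetical util: build one restore-arg record per leaf, reusing the
    mesh/sharding that was stored in the training context when the worker
    group started, instead of requiring a fully sharded `target` PyTree."""
    sharding = context["sharding"]
    return {
        name: {"mesh": sharding["mesh"], "sharding": sharding}
        for name in target_leaves
    }


# Toy usage: the context carries the sharding, the caller only names the leaves.
context = {"sharding": {"mesh": "mesh0"}}
args = build_restore_args(["params/w", "params/b"], context)
```

The design question this sketch captures is exactly the one in the comment: if the sharding lives in the training context, the user never has to thread a pre-sharded `target` through restore.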


Labels

community-contribution Contributed by the community


3 participants