
[train] XGBoostTrainer and LightGBMTrainer API revamps #50042

@justinvyu

Description


Context

We are deprecating the legacy XGBoostTrainer and LightGBMTrainer APIs in favor of a new, more flexible API built on DataParallelTrainer, which takes in a custom training function. This migration unifies the training interface across frameworks.

Train V2 only supports this new API.

⭐ Motivation

  • Improve flexibility for training logic
    • The old API wrapped a limited subset of the XGBoost/LightGBM APIs. For example, users previously could not use xgboost's scikit-learn interface, e.g. to plug into scikit-learn's cross-validation utilities.
    • The new design allows full access to the native APIs (xgboost.train, xgboost.XGBRegressor, lightgbm.train, etc.), making it easier to adopt new upstream features.
  • Increase transparency
    • The new API clearly shows the separation of responsibility between Ray Train’s distributed XGBoost setup logic and user training code.
    • The new API also makes data ingestion clearer, since the user now explicitly triggers the dataset materialization and wraps it in a framework-native DMatrix (or Dataset for LightGBM).
  • Simplify migration from non-distributed code
    • Users can now wrap their existing single-node training functions in XGBoostTrainer or LightGBMTrainer with minimal changes.
  • Remove deprecated dependencies
    • The old trainers relied on xgboost_ray / lightgbm_ray external packages, which are no longer maintained.

🚀 Quickstart

Runnable examples showcasing the new APIs can be found here:

🔁 Migration guide

The examples below use XGBoost, but the same migrations apply to the LightGBMTrainer with the corresponding LightGBM APIs.

| V1 Argument | Migration Strategy (V2) | Example |
| --- | --- | --- |
| `datasets` | Continue to pass your Ray Datasets into the Trainer, but now access them from your training function with `ray.train.get_dataset_shard("train")` and convert them to an input data format that is compatible with XGBoost/LightGBM. | `ds_shard = ray.train.get_dataset_shard("train")`<br>`shard_df = ds_shard.materialize().to_pandas()`<br>`dtrain = xgboost.DMatrix(shard_df)` |
| `label_column` | Handle in your training function by dropping the label column from your input features. | `X, y = df.drop(columns="target"), df["target"]` |
| `params` | Pass directly into the framework's training API (`xgboost.train` / `lightgbm.train`). | `xgboost.train(params, ...)` |
| `dmatrix_params` | Pass into the `xgboost.DMatrix()` constructor. | `dtrain = xgboost.DMatrix(..., **dmatrix_params)` |
| `num_boost_round` | Pass as a keyword argument to the framework training API. | `xgboost.train(..., num_boost_round=100)` |
| `**train_kwargs` | Pass directly into the framework API call. | `xgboost.train(..., **train_kwargs)` |
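Putting the migrations above together, a migrated training function might look like the sketch below. This is a minimal sketch, not a definitive implementation: the helper `split_features_label`, the column name `"target"`, and the exact `XGBoostTrainer` constructor signature in the driver snippet are assumptions for illustration; check the Ray Train V2 docs for the current API.

```python
import pandas as pd


def split_features_label(df: pd.DataFrame, label_column: str):
    """`label_column` migration: drop the label from the input features manually."""
    return df.drop(columns=label_column), df[label_column]


def train_fn_per_worker(config: dict):
    # Imports are kept inside the worker function so this sketch stays
    # importable without a Ray/XGBoost installation.
    import ray.train
    import xgboost

    # `datasets` migration: fetch and materialize this worker's shard.
    ds_shard = ray.train.get_dataset_shard("train")
    shard_df = ds_shard.materialize().to_pandas()

    X, y = split_features_label(shard_df, "target")

    # `dmatrix_params` migration: pass them to the DMatrix constructor.
    dtrain = xgboost.DMatrix(X, label=y, **config.get("dmatrix_params", {}))

    # `params` / `num_boost_round` / `**train_kwargs` migrations: pass them
    # straight into the native training API.
    xgboost.train(
        config["params"],
        dtrain,
        num_boost_round=config.get("num_boost_round", 10),
    )


if __name__ == "__main__":
    # Hypothetical driver code; constructor arguments are assumptions.
    import ray.data
    from ray.train import ScalingConfig
    from ray.train.xgboost import XGBoostTrainer

    train_ds = ray.data.from_pandas(
        pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0], "target": [0, 1, 0, 1]})
    )
    trainer = XGBoostTrainer(
        train_fn_per_worker,
        train_loop_config={"params": {"objective": "reg:squarederror"}},
        scaling_config=ScalingConfig(num_workers=2),
        datasets={"train": train_ds},
    )
    trainer.fit()
```

Note how the training function itself is plain XGBoost code; Ray Train only supplies the data shard and the distributed setup, which is the separation of responsibilities the new design is aiming for.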

LightGBM specific migrations

You need to merge the output of ray.train.lightgbm.get_network_params() into your LightGBM training parameters.

See the quickstart for a full example.
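A minimal sketch of what that merge looks like in a LightGBM training function. The helper `with_network_params` and the function `lightgbm_train_fn` are hypothetical names for illustration; only `ray.train.lightgbm.get_network_params()` comes from this issue.

```python
def with_network_params(params: dict, network_params: dict) -> dict:
    """Merge Ray Train's distributed-networking params into the user's params.

    In real code, pass ray.train.lightgbm.get_network_params() as
    `network_params`.
    """
    merged = dict(params)
    merged.update(network_params)
    return merged


def lightgbm_train_fn(config: dict):
    # Imports kept inside the function so the sketch stays importable
    # without Ray/LightGBM installed.
    import lightgbm
    import ray.train
    import ray.train.lightgbm

    ds_shard = ray.train.get_dataset_shard("train")
    shard_df = ds_shard.materialize().to_pandas()
    X, y = shard_df.drop(columns="target"), shard_df["target"]
    train_set = lightgbm.Dataset(X, label=y)

    # LightGBM-specific migration: merge in the network parameters that
    # Ray Train provides for the distributed workers.
    params = with_network_params(
        config["params"], ray.train.lightgbm.get_network_params()
    )
    lightgbm.train(params, train_set)
```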
