[train] XGBoostTrainer and LightGBMTrainer API revamps #50042
Context
We are deprecating the legacy XGBoostTrainer and LightGBMTrainer APIs in favor of a new, more flexible DataParallelTrainer-style API that takes in a custom training function. This migration unifies the training interface across frameworks.
Train V2 only supports this new API.
⭐ Motivation
- Improve flexibility for training logic
  - The old API wrapped a limited subset of the XGBoost/LightGBM APIs. For example, users previously could not use `xgboost`'s `scikit-learn` interface in order to use `scikit-learn`'s cross-validation utilities.
  - The new design allows full access to the native APIs (`xgboost.train`, `xgboost.XGBRegressor`, `lightgbm.train`, etc.), making it easier to adopt new upstream features.
- Increase transparency
  - The new API clearly shows the separation of responsibility between Ray Train's distributed XGBoost setup logic and user training code.
  - The new API also makes data ingestion clearer, since the user now explicitly triggers the dataset materialization and wraps it in a framework-native `DMatrix` (or `Dataset` for LightGBM).
- Simplify migration from non-distributed code
  - Users can now wrap their existing single-node training functions in `XGBoostTrainer` or `LightGBMTrainer` with minimal changes.
- Remove deprecated dependencies
  - The old trainers relied on the external `xgboost_ray`/`lightgbm_ray` packages, which are no longer maintained.
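To make the new shape concrete, below is a minimal sketch of a worker training function in the revamped style. It assumes a dataset registered under the name `"train"` with a label column named `"target"`; both names, the hyperparameters, and the boost-round count are illustrative, not prescribed by this issue.

```python
def train_fn_per_worker(config: dict):
    # Framework imports live inside the worker function so the function
    # can be shipped to Ray workers as-is.
    import xgboost
    import ray.train

    # 1. Get this worker's shard of the Ray Dataset and materialize it.
    shard = ray.train.get_dataset_shard("train")
    df = shard.materialize().to_pandas()

    # 2. Split off the (illustrative) "target" label column and build a
    #    native DMatrix from the remaining features.
    X, y = df.drop(columns=["target"]), df["target"]
    dtrain = xgboost.DMatrix(X, label=y)

    # 3. Call the native XGBoost training API directly.
    xgboost.train(
        {"objective": "reg:squarederror"},  # illustrative params
        dtrain,
        num_boost_round=10,
    )
```

You would then pass this function to the revamped trainer, e.g. `ray.train.xgboost.XGBoostTrainer(train_fn_per_worker, scaling_config=..., datasets={"train": ds})`, and call `.fit()`; see the quickstart below for complete runnable examples.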
🚀 Quickstart
Find runnable examples showcasing the new APIs here:
🔁 Migration guide
We show the XGBoost migration examples, but the same migrations apply to LightGBMTrainer with the corresponding LightGBM APIs.
| V1 Argument | Migration Strategy (V2) | Example |
|---|---|---|
| `datasets` | Continue to pass your Ray Datasets into the Trainer, but now access them from your training function with `ray.train.get_dataset_shard("train")` and convert them to an input data format that is compatible with XGBoost/LightGBM. | `ds_shard = ray.train.get_dataset_shard("train")`<br>`shard_df = ds_shard.materialize().to_pandas()`<br>`dtrain = xgboost.DMatrix(shard_df)` |
| `label_column` | Split off the label inside your training function and drop the column from your input features. | `X, y = df.drop(columns=["target"]), df["target"]` |
| `params` | Pass directly into the framework's training API (`xgboost.train` / `lightgbm.train`). | `xgboost.train(params, ...)` |
| `dmatrix_params` | Pass into the `xgboost.DMatrix()` constructor. | `dtrain = xgboost.DMatrix(..., **dmatrix_params)` |
| `num_boost_round` | Pass as a keyword argument to the framework training API. | `xgboost.train(..., num_boost_round=100)` |
| `**train_kwargs` | Pass directly into the framework API call. | `xgboost.train(..., **train_kwargs)` |
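The `label_column` row can be exercised with plain pandas, independent of Ray. A tiny sketch, where `"target"` stands in for whatever you previously passed as `label_column`:

```python
import pandas as pd

# Tiny illustrative frame; "target" stands in for your label_column.
df = pd.DataFrame(
    {"f0": [1.0, 2.0, 3.0], "f1": [0.1, 0.2, 0.3], "target": [0, 1, 0]}
)

# V2 replacement for `label_column="target"`: split the label off inside
# the training function, before building the DMatrix.
X, y = df.drop(columns=["target"]), df["target"]
```

`X` now holds only the feature columns and `y` the labels, ready to be passed to `xgboost.DMatrix(X, label=y)`.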
LightGBM specific migrations
You need to merge the result of `ray.train.lightgbm.get_network_params()` into your LightGBM parameters dict.
See the quickstart for a full example.
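A minimal sketch of that merge inside a LightGBM worker function, assuming the same illustrative `"train"` dataset and `"target"` label column as above (the objective and round count are placeholders):

```python
def lightgbm_train_fn_per_worker(config: dict):
    # Framework imports live inside the worker function so it can be
    # shipped to Ray workers as-is.
    import lightgbm
    import ray.train
    import ray.train.lightgbm

    # Materialize this worker's shard and build a native lightgbm.Dataset.
    shard = ray.train.get_dataset_shard("train")
    df = shard.materialize().to_pandas()
    X, y = df.drop(columns=["target"]), df["target"]
    train_set = lightgbm.Dataset(X, label=y)

    params = {"objective": "regression"}  # illustrative params
    # Distributed LightGBM needs the worker network topology; Ray Train
    # supplies it via get_network_params(), merged into the params dict.
    params.update(ray.train.lightgbm.get_network_params())

    lightgbm.train(params, train_set, num_boost_round=10)
```

The rest of the setup mirrors the XGBoost case: pass the function to `LightGBMTrainer` and call `.fit()`.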