Autologging functionality for scikit-learn integration with XGBoost (Part 2)#5055
Autologging functionality for scikit-learn integration with XGBoost (Part 2)#5055jwyyy wants to merge 7 commits into
Conversation
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
| mlflow.sklearn.log_model(sk_model, "sk_models") | ||
| """ | ||
| return Model.log( | ||
| Model.log( |
There was a problem hiding this comment.
It seems Model.log() doesn't return any value. Maybe we can remove return.
|
|
||
|
|
||
| def _autolog( | ||
| flavor_name=FLAVOR_NAME, |
There was a problem hiding this comment.
Internal API for sklearn autologging. The flavor_name field allows mlflow.xgboost to specify the xgboost_sklearn flavor, preventing flavor conflict with mlflow.sklearn.
|
|
||
| def _mlflow_xgboost_logging( | ||
| importance_types, autologging_client, logger, original, sklearn_estimator, *args, **kwargs, | ||
| ): |
There was a problem hiding this comment.
Re-organize early stopping call backs and feature importance plot. This function is re-used in mlflow.sklearn for logging XGBoost sklearn estimators.
|
|
||
| safe_patch_call_count = ( | ||
| safe_patch_mock.call_count + xgb_sklearn_safe_patch_mock.call_count | ||
| ) |
There was a problem hiding this comment.
Since mlflow.sklearn._autolog() is called inside mlflow.xgboost, we need to count safe_patch called due to enabling sklearn autologging.
|
Hi @harupy @dbczumar, I made a new PR to complete the autologging functionality for XGBoost sklearn estimators. It is based on our previous discussion #4885. I left a few comments in the PR to highlight the changes:
Please correct me if I missed anything. Also please let me know your feedback and suggestions! Thanks a lot! |
|
Regarding the tests, I was trying to integrate XGBoost sklearn estimators tests with the existing tests: change # current
expected_params = {"num_boost_round": 20, "early_stopping_rounds": 5, "verbose_eval": False}
xgb.train(bst_params, dtrain, evals=[(dtrain, "train")], **expected_params)to something like # new
def xgb_train(mode, bst_params, data, other_kwargs):
if mode == "xgboost_sklearn":
# return XGBoost sklearn model using bst_params and other_kwargs
else:
# mode == "xgboost"
# return xgb.train(...)
# insider a test function
xgb_train(mode, bst_params, data, other_kwargs)but the integration could be messy in this way. Not all parameters passed to xgboost.train(bst_params, dtrain, **kwargs)but xgb_sklearn_model = xgboost.XGBClassifier(**bst_params, **kwargs)
xgb_sklearn_model.fit(X, y) # X, y from dtraingenerally is not error proof. Here is an example:
xgb_classifier = xgb.XGBClassifier(objective="multi:softprob", num_class=3, n_estimators=20)
xgb_classifier.fit(X, y, eval_metric=["merror", "mlogloss"], eval_set=[(X1,y1),(X2,y2)])@harupy @dbczumar Should we keep doing the integration approach? Or is it a better idea to create new separate tests for XGBoost sklearn models? What are your opinions / suggestions? Thanks! |
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
What changes are proposed in this pull request?
This is the second PR to add autologging for XGBoost sklearn models using
mlflow.sklearnautologging routine.(Previous PR: #4954)
(Draft + discussion: #4885)
How is this patch tested?
A new example is provided. Tests will be added later.
Does this PR change the documentation?
ci/circleci: build_doccheck. If it's successful, proceed to thenext step, otherwise fix it.
Detailson the right to open the job page of CircleCI.Artifactstab.docs/build/html/index.html.Release Notes
Is this a user-facing change?
Success merge of this PR will enable autologging for XGBoost scikit-learn models.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/artifacts: Artifact stores and artifact loggingarea/build: Build and test infrastructure for MLflowarea/docs: MLflow documentation pagesarea/examples: Example codearea/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registryarea/models: MLmodel format, model serialization/deserialization, flavorsarea/projects: MLproject format, project running backendsarea/scoring: MLflow Model server, model deployment tools, Spark UDFsarea/server-infra: MLflow Tracking server backendarea/tracking: Tracking Service, tracking client APIs, autologgingInterface
area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev serverarea/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Modelsarea/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registryarea/windows: Windows supportLanguage
language/r: R APIs and clientslanguage/java: Java APIs and clientslanguage/new: Proposals for new client languagesIntegrations
integrations/azure: Azure and Azure ML integrationsintegrations/sagemaker: SageMaker integrationsintegrations/databricks: Databricks integrationsHow should the PR be classified in the release notes? Choose one:
rn/breaking-change- The PR will be mentioned in the "Breaking Changes" sectionrn/none- No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" sectionrn/feature- A new user-facing feature worth mentioning in the release notesrn/bug-fix- A user-facing bug fix worth mentioning in the release notesrn/documentation- A user-facing documentation change worth mentioning in the release notes