Ensure all autologging callbacks are picklable #5039
Signed-off-by: harupy <hkawamura0130@gmail.com>
Fastai doesn't seem to support distributed training using CPUs.
Distributed training in MXNet seems pretty complex: https://mxnet.apache.org/versions/1.8.0/api/faq/distributed_training
For sklearn, we already have a test for parallelised training (mlflow/tests/sklearn/test_sklearn_autolog.py, lines 732 to 733 in d6ae841). Is there anything else that we need to test for scikit-learn?
Does this test case use a multiprocess backend?
@dbczumar Thanks for the comment!
We run the test using mlflow/tests/sklearn/test_sklearn_autolog.py Lines 715 to 733 in d6ae841
Sounds good!
Done!
def picklable_exception_safe_function(function):
    """
    Wraps the specified function with broad exception handling to guard
    against unexpected errors during autologging while preserving picklability.
    """
    if is_testing():
        setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)

    return update_wrapper_extended(functools.partial(_safe_function, function), function)
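For context on why the snippet above builds the wrapper with functools.partial, here is a minimal sketch comparing it with a closure-based wrapper. It is not the MLflow implementation: `closure_safe`, `partial_safe`, `callback`, and `is_picklable` are hypothetical names for illustration, and the exception handling inside `_safe_function` is a placeholder assumption.

```python
import functools
import pickle


def _safe_function(function, *args, **kwargs):
    # Module-level helper, so pickle can reference it by name.
    try:
        return function(*args, **kwargs)
    except Exception:
        # Placeholder handling (assumption); real code may log instead.
        return None


def closure_safe(function):
    # Closure-based wrapper: `wrapper` is a local object that pickle
    # cannot look up by name, so pickling it fails.
    def wrapper(*args, **kwargs):
        return _safe_function(function, *args, **kwargs)

    return wrapper


def partial_safe(function):
    # A partial is pickled as its target function plus frozen arguments,
    # both of which are module-level objects here.
    return functools.partial(_safe_function, function)


def callback(x):
    # Hypothetical stand-in for an autologging callback.
    return x + 1


def is_picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


print("partial wrapper picklable:", is_picklable(partial_safe(callback)))
print("closure wrapper picklable:", is_picklable(closure_safe(callback)))
```

The partial-based wrapper survives a pickle round trip because everything it references is importable by name, which is what a multiprocessing backend needs when it ships callbacks to worker processes.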
@dbczumar I tried using picklable_exception_safe_function in _ExceptionSafeClass:
diff --git a/mlflow/utils/autologging_utils/safety.py b/mlflow/utils/autologging_utils/safety.py
index 11ffde645..4b15e38ea 100644
--- a/mlflow/utils/autologging_utils/safety.py
+++ b/mlflow/utils/autologging_utils/safety.py
@@ -96,7 +96,7 @@ def _exception_safe_class_factory(base_class):
     class _ExceptionSafeClass(base_class):
         def __new__(cls, name, bases, dct):
             for m in dct:
                 # class methods or static methods are not callable.
                 if callable(dct[m]):
-                    dct[m] = exception_safe_function(dct[m])
+                    dct[m] = picklable_exception_safe_function(dct[m])
             return base_class.__new__(cls, name, bases, dct)
     return _ExceptionSafeClass

but this gave:
% pytest tests/xgboost/test_xgboost_autolog.py -k is_pickable
============================================================ test session starts =============================================================
platform linux -- Python 3.7.10, pytest-6.2.5, py-1.10.0, pluggy-1.0.0 -- /home/haru/miniconda3/envs/mlflow-dev-env/bin/python
cachedir: .pytest_cache
rootdir: /home/haru/Desktop/repositories/mlflow, configfile: pytest.ini
collected 29 items / 27 deselected / 2 selected
tests/xgboost/test_xgboost_autolog.py::test_callback_func_is_pickable PASSED [ 50%]
tests/xgboost/test_xgboost_autolog.py::test_callback_class_is_pickable FAILED [100%]
================================================================== FAILURES ==================================================================
______________________________________________________ test_callback_class_is_pickable _______________________________________________________
@pytest.mark.skipif(
Version(xgb.__version__.replace("SNAPSHOT", "dev")) < Version("1.3.0"),
reason="`xgboost.callback.TrainingCallback` is not supported",
)
def test_callback_class_is_pickable():
from mlflow.xgboost._autolog import AutologCallback
> cb = AutologCallback(BatchMetricsLogger(run_id="1234"), eval_results={})
AutologCallback = <class 'mlflow.xgboost._autolog.AutologCallback'>
tests/xgboost/test_xgboost_autolog.py:577:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
function = <function AutologCallback.__init__ at 0x7fd1c2699dd0>
args = (<mlflow.utils.autologging_utils.BatchMetricsLogger object at 0x7fd19e0ac890>,), kwargs = {'eval_results': {}}
def _safe_function(function, *args, **kwargs):
try:
> return function(*args, **kwargs)
E TypeError: __init__() missing 1 required positional argument: 'metrics_logger'
args = (<mlflow.utils.autologging_utils.BatchMetricsLogger object at 0x7fd19e0ac890>,)
function = <function AutologCallback.__init__ at 0x7fd1c2699dd0>
kwargs = {'eval_results': {}}
mlflow/utils/autologging_utils/safety.py:56: TypeError

It appears that functools.partial alters the __init__ behavior. I'm investigating a workaround.
@harupy This solution makes sense once we can resolve the __init__ issues. Let's also make sure to remove exception_safe_function and replace it with picklable_exception_safe_function.
This SO post says:
partial objects are like function objects in that they are callable, weak referencable, and can have attributes. There are some important differences. For instance, the __name__ and __doc__ attributes are not created automatically. Also, partial objects defined in classes behave like static methods and do not transform into bound methods during instance attribute look-up.
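The quoted behavior can be reproduced in isolation. The sketch below uses hypothetical names (`_init`, `Plain`, `Wrapped`) and a simplified `_safe_function` without exception handling. It reflects the Python 3.7 interpreter shown in the traceback above; newer Python releases have been revisiting how partial objects behave as class attributes, so the outcome should be treated as version-dependent.

```python
import functools
import sys


def _safe_function(function, *args, **kwargs):
    # Simplified: no exception handling, so the TypeError stays visible.
    return function(*args, **kwargs)


def _init(self, value):
    self.value = value


class Plain:
    # A plain function stored on a class becomes a bound method,
    # so the instance is passed as `self`.
    __init__ = _init


class Wrapped:
    # On the Python version used in this thread, a partial stored on a
    # class is not bound to the instance, so `_init` never receives `self`.
    __init__ = functools.partial(_safe_function, _init)


plain = Plain(1)

error = None
try:
    Wrapped(1)
except TypeError as exc:
    error = exc

print(sys.version_info[:2], "->", error)
```

When the partial is not bound, `Wrapped(1)` ends up calling `_init(1)` with the instance missing, which mirrors the "missing 1 required positional argument" failure reported earlier in this thread.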
This SO post suggests using functools.partialmethod, but this makes a class non-callable:
__________________________________________________________________________________ test_callback_class_is_pickable __________________________________________________________________________________
@pytest.mark.skipif(
not IS_TRAINING_CALLBACK_SUPPORTED,
reason="`xgboost.callback.TrainingCallback` is not supported",
)
def test_callback_class_is_pickable():
from mlflow.xgboost._autolog import AutologCallback
> cb = AutologCallback(BatchMetricsLogger(run_id="1234"), eval_results={})
AutologCallback = <class 'mlflow.xgboost._autolog.AutologCallback'>
tests/xgboost/test_xgboost_autolog.py:577:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
function = <mlflow.xgboost._autolog.AutologCallback object at 0x7f8a8bb92850>
args = (<function AutologCallback.__init__ at 0x7f8ab0311050>, <mlflow.utils.autologging_utils.BatchMetricsLogger object at 0x7f8a8bcf04d0>), kwargs = {'eval_results': {}}
def _safe_function(function, *args, **kwargs):
try:
> return function(*args, **kwargs)
E TypeError: 'AutologCallback' object is not callable
args = (<function AutologCallback.__init__ at 0x7f8ab0311050>, <mlflow.utils.autologging_utils.BatchMetricsLogger object at 0x7f8a8bcf04d0>)
function = <mlflow.xgboost._autolog.AutologCallback object at 0x7f8a8bb92850>
kwargs = {'eval_results': {}}

Git diff:
diff --git a/mlflow/utils/autologging_utils/safety.py b/mlflow/utils/autologging_utils/safety.py
index 11ffde645..e76932f7c 100644
--- a/mlflow/utils/autologging_utils/safety.py
+++ b/mlflow/utils/autologging_utils/safety.py
@@ -72,6 +72,17 @@ def picklable_exception_safe_function(function):
     return update_wrapper_extended(functools.partial(_safe_function, function), function)


+def picklable_exception_safe_method(function):
+    """
+    Wraps the specified function with broad exception handling to guard
+    against unexpected errors during autologging while preserving picklability.
+    """
+    if is_testing():
+        setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)
+
+    return update_wrapper_extended(functools.partialmethod(_safe_function, function), function)
+
+
 def _exception_safe_class_factory(base_class):
     """
     Creates an exception safe metaclass that inherits from `base_class`.
@@ -96,7 +107,7 @@ def _exception_safe_class_factory(base_class):
             for m in dct:
                 # class methods or static methods are not callable.
                 if callable(dct[m]):
-                    dct[m] = exception_safe_function(dct[m])
+                    dct[m] = picklable_exception_safe_method(dct[m])
             return base_class.__new__(cls, name, bases, dct)

     return _ExceptionSafeClass
Got it. Let's ignore https://github.com/mlflow/mlflow/pull/5039/files#r750767771 and use a separate function for methods defined on classes :). Thanks @harupy !
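As a hedged illustration of what a separate wrapper for methods could look like (not necessarily what this PR merged), the sketch below shows why `functools.partialmethod(_safe_function, function)` failed and one way around it. partialmethod passes the instance first and the frozen arguments after it, so the instance landed in the `function` parameter, matching the "'AutologCallback' object is not callable" traceback. `_safe_method`, `safe_method`, and `Callback` are hypothetical names, and the exception handling is a placeholder assumption.

```python
import functools
import pickle


def _safe_function(function, *args, **kwargs):
    try:
        return function(*args, **kwargs)
    except Exception:
        # Placeholder handling (assumption); real code may log instead.
        return None


def _safe_method(self, function, *args, **kwargs):
    # partialmethod calls this as _safe_method(instance, function, ...),
    # so the wrapped function is the second parameter and the instance
    # is forwarded to it explicitly.
    return _safe_function(function, self, *args, **kwargs)


def safe_method(function):
    # Hypothetical helper for methods defined on classes.
    return functools.partialmethod(_safe_method, function)


class Callback:
    def __init__(self, name):
        self.name = name

    def fail(self):
        raise RuntimeError("boom")

    # Rebind the methods through the wrapper inside the class body.
    __init__ = safe_method(__init__)
    fail = safe_method(fail)


cb = Callback("metrics")
result = cb.fail()  # the exception is swallowed by _safe_function
restored = pickle.loads(pickle.dumps(cb))
print(cb.name, result, restored.name)
```

Instances of a module-level class are pickled by class reference plus instance state, so the round trip here does not depend on how the methods are wrapped.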
Signed-off-by: harupy hkawamura0130@gmail.com
What changes are proposed in this pull request?
Ensure all autologging callbacks are picklable.
How is this patch tested?
Unit tests
Does this PR change the documentation?
1. Check the status of the ci/circleci: build_doc check. If it's successful, proceed to the next step, otherwise fix it.
2. Click Details on the right to open the job page of CircleCI.
3. Click the Artifacts tab.
4. Click docs/build/html/index.html.
Release Notes
Is this a user-facing change?
(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)
What component(s), interfaces, languages, and integrations does this PR affect?
Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/projects: MLproject format, project running backends
- area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- area/server-infra: MLflow Tracking server backend
- area/tracking: Tracking Service, tracking client APIs, autologging
Interface
- area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support
Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages
Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations
How should the PR be classified in the release notes? Choose one:
- rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/feature - A new user-facing feature worth mentioning in the release notes
- rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- rn/documentation - A user-facing documentation change worth mentioning in the release notes