Skip to content

Commit 2ccfc8c

Browse files
fcharrasbetatimogrisel
committed
Engine plugin API and engine entry point for Lloyd's KMeans
Co-authored-by: Tim Head <betatim@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Franck Charras <29153872+fcharras@users.noreply.github.com>
1 parent 59e5070 commit 2ccfc8c

14 files changed

Lines changed: 1046 additions & 68 deletions

File tree

doc/computing.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ Computing with scikit-learn
1414
computing/scaling_strategies
1515
computing/computational_performance
1616
computing/parallelism
17+
computing/engine

doc/computing/engine.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
.. Places parent toc into the sidebar
2+
3+
:parenttoc: True
4+
5+
.. _engine:
6+
7+
Computation Engines (experimental)
8+
==================================
9+
10+
**This API is experimental** which means that it is subject to change without
11+
any backward compatibility guarantees.
12+
13+
TODO: explain goals here
14+
15+
Activating an engine
16+
--------------------
17+
18+
TODO: installing third party engine provider packages
19+
20+
TODO: how to list installed engines
21+
22+
TODO: how to install a plugin
23+
24+
Writing a new engine provider
25+
-----------------------------
26+
27+
TODO: show engine API of a given estimator.
28+
29+
TODO: give example setup.py with setuptools to define an entrypoint.

doc/whats_new/v1.4.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ Changes impacting all modules
4141
to work with our estimators and functions.
4242
:pr:`26464` by `Thomas Fan`_.
4343

44+
- |Enhancement| Experimental engine API (no backward compatibility guarantees)
45+
to allow for external packages to contribute alternative implementations for
46+
the core computational routines of some selected scikit-learn estimators.
47+
48+
Currently, the following estimators allow alternative implementations:
49+
50+
- :class:`~sklearn.cluster.KMeans` (only for the Lloyd algorithm).
51+
- TODO: add more when available.
52+
53+
External engine providers include:
54+
55+
- https://github.com/soda-inria/sklearn-numba-dpex which provides a KMeans
56+
engine optimized for OpenCL enabled GPUs.
57+
- TODO: add more here
58+
59+
:pr:`25535` by :user:`ogrisel`, :user:`fcharras` and :user:`betatim`.
60+
61+
4462
Changelog
4563
---------
4664

setup.py

100755100644
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ def setup_package():
600600
python_requires=python_requires,
601601
install_requires=min_deps.tag_to_packages["install"],
602602
package_data={"": ["*.csv", "*.gz", "*.txt", "*.pxd", "*.rst", "*.jpg"]},
603+
entry_points={"pytest11": ["sklearn_plugin_testing = sklearn._engine.testing"]},
603604
zip_safe=False, # the package can run out of an .egg file
604605
extras_require={
605606
key: min_deps.tag_to_packages[key]

sklearn/_config.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Global configuration state and functions for management
22
"""
3+
import inspect
34
import os
45
import threading
56
from contextlib import contextmanager as contextmanager
@@ -14,6 +15,8 @@
1415
),
1516
"enable_cython_pairwise_dist": True,
1617
"array_api_dispatch": False,
18+
"engine_provider": (),
19+
"engine_attributes": "engine_types",
1720
"transform_output": "default",
1821
"enable_metadata_routing": False,
1922
"skip_parameter_validation": False,
@@ -55,6 +58,8 @@ def set_config(
5558
pairwise_dist_chunk_size=None,
5659
enable_cython_pairwise_dist=None,
5760
array_api_dispatch=None,
61+
engine_provider=None,
62+
engine_attributes=None,
5863
transform_output=None,
5964
enable_metadata_routing=None,
6065
skip_parameter_validation=None,
@@ -126,6 +131,26 @@ def set_config(
126131
127132
.. versionadded:: 1.2
128133
134+
engine_provider : str or sequence of {str, engine class}, default=None
135+
Specify list of enabled computational engine implementations provided
136+
by third party packages. Engines are enabled by listing the name of
137+
the provider or listing an engine class directly.
138+
139+
See the :ref:`User Guide <engine>` for more details.
140+
141+
.. versionadded:: 1.4
142+
143+
engine_attributes : str, default=None
144+
Enable conversion of estimator attributes to scikit-learn native
145+
types by setting to "sklearn_types". By default attributes are
146+
stored using engine native types. This avoids additional conversions
147+
and memory transfers between host and device when calling `predict`/
148+
`transform` after `fit` of an engine-aware estimator.
149+
150+
See the :ref:`User Guide <engine>` for more details.
151+
152+
.. versionadded:: 1.4
153+
129154
transform_output : str, default=None
130155
Configure output of `transform` and `fit_transform`.
131156
@@ -185,6 +210,18 @@ def set_config(
185210

186211
_check_array_api_dispatch(array_api_dispatch)
187212
local_config["array_api_dispatch"] = array_api_dispatch
213+
if engine_provider is not None:
214+
# Single provider name was passed in
215+
if isinstance(engine_provider, str):
216+
engine_provider = (engine_provider,)
217+
# Allow direct registration of engine classes to ease testing, debugging
218+
# and benchmarking without having to register a fake package with metadata
219+
# just to use a custom engine not meant to be used by end-users.
220+
elif inspect.isclass(engine_provider):
221+
engine_provider = (engine_provider,)
222+
local_config["engine_provider"] = engine_provider
223+
if engine_attributes is not None:
224+
local_config["engine_attributes"] = engine_attributes
188225
if transform_output is not None:
189226
local_config["transform_output"] = transform_output
190227
if enable_metadata_routing is not None:
@@ -203,6 +240,8 @@ def config_context(
203240
pairwise_dist_chunk_size=None,
204241
enable_cython_pairwise_dist=None,
205242
array_api_dispatch=None,
243+
engine_provider=None,
244+
engine_attributes=None,
206245
transform_output=None,
207246
enable_metadata_routing=None,
208247
skip_parameter_validation=None,
@@ -273,6 +312,24 @@ def config_context(
273312
274313
.. versionadded:: 1.2
275314
315+
engine_provider : str or sequence of {str, engine class}, default=None
316+
Specify list of enabled computational engine implementations provided
317+
by third party packages. Engines are enabled by listing the name of
318+
the provider or listing an engine class directly.
319+
320+
See the :ref:`User Guide <engine>` for more details.
321+
322+
.. versionadded:: 1.4
323+
324+
engine_attributes : str, default=None
325+
Enable conversion of estimator attributes to scikit-learn native
326+
types by setting to "sklearn_types". By default attributes are
327+
stored using engine native types.
328+
329+
See the :ref:`User Guide <engine>` for more details.
330+
331+
.. versionadded:: 1.4
332+
276333
transform_output : str, default=None
277334
Configure output of `transform` and `fit_transform`.
278335
@@ -344,6 +401,8 @@ def config_context(
344401
pairwise_dist_chunk_size=pairwise_dist_chunk_size,
345402
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
346403
array_api_dispatch=array_api_dispatch,
404+
engine_provider=engine_provider,
405+
engine_attributes=engine_attributes,
347406
transform_output=transform_output,
348407
enable_metadata_routing=enable_metadata_routing,
349408
skip_parameter_validation=skip_parameter_validation,

sklearn/_engine/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .base import convert_attributes, get_engine_classes, list_engine_provider_names
2+
3+
__all__ = ["convert_attributes", "get_engine_classes", "list_engine_provider_names"]

sklearn/_engine/base.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import inspect
2+
import warnings
3+
from functools import lru_cache, wraps
4+
from importlib import import_module
5+
from importlib.metadata import entry_points
6+
7+
from sklearn._config import get_config
8+
9+
SKLEARN_ENGINES_ENTRY_POINT = "sklearn_engines"
10+
11+
12+
class EngineSpec:
    """Description of one engine implementation advertised by a provider.

    A spec only stores names; the engine module is not imported until
    :meth:`get_engine_class` is called.

    Parameters
    ----------
    name : str
        Name of the engine, i.e. of the algorithm it implements.

    provider_name : str
        Name of the provider that ships the engine.

    module_name : str
        Dotted path of the module that defines the engine class.

    engine_qualname : str
        Qualified name of the engine class inside ``module_name``. It may be
        dotted to reach a nested attribute (e.g. ``"Outer.Inner"``).
    """

    __slots__ = ["name", "provider_name", "module_name", "engine_qualname"]

    def __init__(self, name, provider_name, module_name, engine_qualname):
        self.name = name
        self.provider_name = provider_name
        self.module_name = module_name
        self.engine_qualname = engine_qualname

    def __repr__(self):
        # Built from __slots__ so the repr stays in sync with the fields.
        fields = ", ".join(
            f"{slot}={getattr(self, slot)!r}" for slot in self.__slots__
        )
        return f"{type(self).__name__}({fields})"

    def get_engine_class(self):
        """Import ``module_name`` and return the object named by ``engine_qualname``.

        Raises
        ------
        ImportError
            If the module cannot be imported.
        AttributeError
            If any component of ``engine_qualname`` cannot be resolved.
        """
        engine = import_module(self.module_name)
        # Walk the dotted qualified name one attribute at a time so that
        # nested classes can be resolved.
        for attr in self.engine_qualname.split("."):
            engine = getattr(engine, attr)
        return engine
26+
27+
28+
def _parse_entry_point(entry_point):
    """Build an :class:`EngineSpec` from a single entry point.

    Parameters
    ----------
    entry_point : object
        Object exposing ``name`` and ``value`` attributes, where ``value``
        has the form ``"package.module:Qualified.Name"``.

    Returns
    -------
    spec : EngineSpec
        Spec whose provider name is the first dotted component of the module
        path.

    Raises
    ------
    ValueError
        If ``value`` does not consist of exactly one non-empty module path and
        one non-empty qualified name separated by a single ``":"``.
    """
    parts = entry_point.value.split(":")
    if len(parts) != 2:
        raise ValueError(
            "expected a value of the form 'module:qualname', got"
            f" {entry_point.value!r}"
        )
    # Whitespace is tolerated around the separator in entry point values.
    module_name, engine_qualname = (part.strip() for part in parts)
    if not module_name or not engine_qualname:
        raise ValueError(
            "expected a value of the form 'module:qualname', got"
            f" {entry_point.value!r}"
        )
    provider_name = module_name.split(".", 1)[0]
    return EngineSpec(entry_point.name, provider_name, module_name, engine_qualname)
32+
33+
34+
@lru_cache
def _parse_entry_points(provider_names=None):
    """Collect the engine specs declared through installed entry points.

    Parameters
    ----------
    provider_names : tuple of str, default=None
        When given, only specs whose provider name is in this collection are
        kept. It must be hashable because results are memoized per argument.

    Returns
    -------
    specs : list of EngineSpec
        The parsed specs. Entry points that fail to parse are reported with a
        warning and skipped.

    Raises
    ------
    RuntimeError
        If ``provider_names`` is given and at least one requested provider
        has no matching spec.
    """
    discovered = entry_points()
    # Depending on the Python version, the result either offers ``select`` or
    # behaves like a mapping keyed by group name.
    if hasattr(discovered, "select"):
        candidates = discovered.select(group=SKLEARN_ENGINES_ENTRY_POINT)
    else:
        candidates = discovered.get(SKLEARN_ENGINES_ENTRY_POINT, ())

    restrict = provider_names is not None
    specs = []
    for entry_point in candidates:
        try:
            spec = _parse_entry_point(entry_point)
            # Keep only the specs of requested providers, if any were given.
            if not restrict or spec.provider_name in provider_names:
                specs.append(spec)
        except Exception as e:
            # A broken third-party package installed alongside scikit-learn
            # must not make discovery fail: warn and move on.
            warnings.warn(
                f"Invalid {SKLEARN_ENGINES_ENTRY_POINT} entry point"
                f" {entry_point.name} with value {entry_point.value}: {e}"
            )

    if not restrict:
        return specs

    found = {spec.provider_name for spec in specs}
    missing_providers = set(provider_names) - found
    if missing_providers:
        listing = ", ".join(map(repr, sorted(missing_providers)))
        raise RuntimeError(
            "Could not find any provider for the"
            f" {SKLEARN_ENGINES_ENTRY_POINT} entry point with name(s):"
            f" {listing}"
        )
    return specs
67+
68+
69+
def list_engine_provider_names():
    """Find the list of sklearn_engines provider names.

    This function only inspects the entry point metadata and should not
    trigger any module import.

    Returns
    -------
    provider_names : list of str
        Sorted, de-duplicated provider names.
    """
    return sorted({spec.provider_name for spec in _parse_entry_points()})
75+
76+
77+
def _get_engine_classes(engine_name, provider_names, engine_specs, default):
78+
specs_by_provider = {}
79+
for spec in engine_specs:
80+
if spec.name != engine_name:
81+
continue
82+
specs_by_provider.setdefault(spec.provider_name, spec)
83+
84+
for provider_name in provider_names:
85+
if inspect.isclass(provider_name):
86+
# The provider name is actually a ready-to-go engine class.
87+
# Instead of a made up string to name this ad-hoc provider
88+
# we use the class itself. This mirrors what the user used
89+
# when they set the config (ad-hoc class or string naming
90+
# a provider).
91+
engine_class = provider_name
92+
if getattr(engine_class, "engine_name", None) != engine_name:
93+
continue
94+
yield engine_class, engine_class
95+
96+
spec = specs_by_provider.get(provider_name)
97+
if spec is not None:
98+
yield spec.provider_name, spec.get_engine_class()
99+
100+
yield "default", default
101+
102+
103+
def get_engine_classes(engine_name, default, verbose=False):
    """Iterate over the engine classes that can provide `engine_name`.

    Candidates come from the `engine_provider` configuration: string entries
    are resolved through entry point definitions, while class entries are
    used directly as ad-hoc engine classes. The default engine is always the
    last candidate.

    Parameters
    ----------
    engine_name : str
        The name of the algorithm for which to find engine classes.

    default : class
        The default engine class to use if no other provider is found.

    verbose : bool, default=False
        If True, print the name of the engine classes that are tried.

    Yields
    ------
    provider : str or class
        The "name" of each matching provider, as listed in the
        `engine_provider` configuration: a string, or the class itself for
        programmatically registered ad-hoc providers.

    engine_class :
        The engine class that implements the algorithm for the given provider.
    """
    enabled_providers = get_config()["engine_provider"]

    if enabled_providers:
        # Only string entries name installed providers; class entries are
        # handled without any entry point lookup.
        named_providers = tuple(
            entry for entry in enabled_providers if not inspect.isclass(entry)
        )
        candidates = _get_engine_classes(
            engine_name=engine_name,
            provider_names=enabled_providers,
            engine_specs=_parse_entry_points(provider_names=named_providers),
            default=default,
        )
        for provider, engine_class in candidates:
            if verbose:
                print(
                    f"trying engine {engine_class.__module__}.{engine_class.__qualname__}."
                )
            yield provider, engine_class
    else:
        # No provider enabled: the default engine is the only candidate.
        yield "default", default
153+
154+
155+
def convert_attributes(method):
    """Convert estimator attributes after calling the decorated method.

    The attributes of an estimator can be stored in "engine native" types
    (default) or "scikit-learn native" types. When the `engine_attributes`
    configuration is set to "sklearn_types", this decorator passes every
    instance attribute through the engine's `convert_to_sklearn_types` after
    the decorated method returns. Use this decorator on methods that set
    estimator attributes.

    Parameters
    ----------
    method : callable
        Method taking the estimator as first argument. When conversion is
        enabled, the estimator must expose `_engine_class` and
        `_default_engine`.

    Returns
    -------
    wrapper : callable
        Wrapped method returning whatever `method` returns.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)

        if get_config()["engine_attributes"] != "sklearn_types":
            return result

        engine = self._engine_class
        # Iterate over a snapshot: setattr is called on the same object
        # during the loop, and mutating the live attribute dict while
        # iterating over it is an error if a key is ever added.
        for name, value in list(vars(self).items()):
            # All attributes are passed to the engine, which can
            # either convert the value (engine specific types) or
            # return it as is (native Python types)
            setattr(self, name, engine.convert_to_sklearn_types(name, value))

        # No matter which engine was used to fit, after the attribute
        # conversion to the sklearn native types the default engine
        # is used.
        self._engine_class = self._default_engine
        self._engine_provider = "default"

        return result

    return wrapper

0 commit comments

Comments
 (0)