-
-
Notifications
You must be signed in to change notification settings - Fork 26.9k
FEA add CumulativeAccuracyDisplay #28752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2630953
5b42e01
cc674b0
aec617f
73ca739
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,313 @@ | ||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from ...utils._plotting import _BinaryClassifierCurveDisplayMixin | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class CumulativeAccuracyDisplay(_BinaryClassifierCurveDisplayMixin): | ||||||||||||||||||||||||||||
| """Cumulative Accuracy Profile (CAP) Curve visualization. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| It is recommended to use | ||||||||||||||||||||||||||||
| :func:`~sklearn.metrics.CumulativeAccuracyDisplay.from_estimator` or | ||||||||||||||||||||||||||||
| :func:`~sklearn.metrics.CumulativeAccuracyDisplay.from_predictions` to create | ||||||||||||||||||||||||||||
| a :class:`~sklearn.metrics.CumulativeAccuracyDisplay`. All parameters are | ||||||||||||||||||||||||||||
| stored as attributes. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Read more in the :ref:`User Guide <visualizations>`. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||
| cumulative_true_positives : ndarray | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| Cumulative number of true positives. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| cumulative_total : ndarray | ||||||||||||||||||||||||||||
| Cumulative number of cases examined. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| estimator_name : str, default=None | ||||||||||||||||||||||||||||
| Name of estimator. If None, the estimator name is not shown. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| pos_label : int, float, bool or str, default=None | ||||||||||||||||||||||||||||
| The class considered as the positive class when computing the metrics. | ||||||||||||||||||||||||||||
| By default, `estimators.classes_[1]` is considered as the positive class. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Attributes | ||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||
| line_ : matplotlib Artist | ||||||||||||||||||||||||||||
| CAP Curve. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ax_ : matplotlib Axes | ||||||||||||||||||||||||||||
| Axes with CAP Curve. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| figure_ : matplotlib Figure | ||||||||||||||||||||||||||||
| Figure containing the curve. | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||
| cumulative_true_positives, | ||||||||||||||||||||||||||||
| cumulative_total, | ||||||||||||||||||||||||||||
| estimator_name=None, | ||||||||||||||||||||||||||||
| pos_label=None, | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| self.estimator_name = estimator_name | ||||||||||||||||||||||||||||
| self.cumulative_true_positives = cumulative_true_positives | ||||||||||||||||||||||||||||
| self.cumulative_total = cumulative_total | ||||||||||||||||||||||||||||
| self.pos_label = pos_label | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def plot( | ||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||
| ax=None, | ||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||
| normalize_scale=False, | ||||||||||||||||||||||||||||
| name=None, | ||||||||||||||||||||||||||||
| plot_chance_level=False, | ||||||||||||||||||||||||||||
| chance_level_kw=None, | ||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| """Plot visualization. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Extra keyword arguments will be passed to matplotlib's ``plot``. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||
| ax : matplotlib axes, default=None | ||||||||||||||||||||||||||||
| Axes object to plot on. If `None`, a new figure and axes is created. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| name : str, default=None | ||||||||||||||||||||||||||||
| Name of CAP Curve for labeling. If `None`, use `estimator_name` if | ||||||||||||||||||||||||||||
| not `None`, otherwise no labeling is shown. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| **kwargs : dict | ||||||||||||||||||||||||||||
| Keyword arguments to be passed to matplotlib's `plot`. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||
| display : :class:`~sklearn.metrics.CumulativeAccuracyDisplay` | ||||||||||||||||||||||||||||
| Object that stores computed values. | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if normalize_scale: | ||||||||||||||||||||||||||||
| self.cumulative_true_positives = ( | ||||||||||||||||||||||||||||
| self.cumulative_true_positives / self.cumulative_true_positives[-1] | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| self.cumulative_total = self.cumulative_total / self.cumulative_total[-1] | ||||||||||||||||||||||||||||
| self.ax_.set_xlim(0, 1) | ||||||||||||||||||||||||||||
| self.ax_.set_ylim(0, 1) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| line_kwargs = {"label": name} if name is not None else {} | ||||||||||||||||||||||||||||
| line_kwargs.update(**kwargs) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| chance_level_line_kw = { | ||||||||||||||||||||||||||||
| "label": "Random Prediction", | ||||||||||||||||||||||||||||
| "color": "k", | ||||||||||||||||||||||||||||
| "linestyle": "--", | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if chance_level_kw is not None: | ||||||||||||||||||||||||||||
| chance_level_line_kw.update(**chance_level_kw) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| (self.line_,) = self.ax_.plot( | ||||||||||||||||||||||||||||
| self.cumulative_total, self.cumulative_true_positives, **line_kwargs | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if plot_chance_level: | ||||||||||||||||||||||||||||
| (self.chance_level_,) = self.ax_.plot( | ||||||||||||||||||||||||||||
| (0, 1), (0, 1), **chance_level_line_kw | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| self.chance_level_ = None | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| xlabel = "Total Cases Examined" | ||||||||||||||||||||||||||||
| ylabel = "Cumulative True Positives" | ||||||||||||||||||||||||||||
| self.ax_.set(xlabel=xlabel, ylabel=ylabel, aspect="equal") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if "label" in line_kwargs: | ||||||||||||||||||||||||||||
| self.ax_.legend(loc="lower right") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return self | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||
| def from_predictions( | ||||||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||||||
| y_true, | ||||||||||||||||||||||||||||
| y_pred, | ||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||
| sample_weight=None, | ||||||||||||||||||||||||||||
| pos_label=None, | ||||||||||||||||||||||||||||
| normalize_scale=False, | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| plot_chance_level=False, | ||||||||||||||||||||||||||||
| name=None, | ||||||||||||||||||||||||||||
| ax=None, | ||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| """Plot the Cumulative Accuracy Profile. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| This is also known as a Gain or Lift Curve for classification, and a Lorenz | ||||||||||||||||||||||||||||
| curve for regression with a positively valued target. | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||
| y_true : array-like of shape (n_samples,) | ||||||||||||||||||||||||||||
| True labels. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| y_pred : array-like of shape (n_samples,) | ||||||||||||||||||||||||||||
| Target scores, can either be probability estimates of the positive | ||||||||||||||||||||||||||||
| class, confidence values, or non-thresholded measure of decisions | ||||||||||||||||||||||||||||
| (as returned by “decision_function” on some classifiers). | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| sample_weight : array-like of shape (n_samples,), default=None | ||||||||||||||||||||||||||||
| Sample weights. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| pos_label : int, float, bool or str, default=None | ||||||||||||||||||||||||||||
| The label of the positive class. When `pos_label=None`, if `y_true` | ||||||||||||||||||||||||||||
| is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an | ||||||||||||||||||||||||||||
| error will be raised. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| name : str, default=None | ||||||||||||||||||||||||||||
| Name of CAP curve for labeling. If `None`, name will be set to | ||||||||||||||||||||||||||||
| `"Classifier"`. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ax : matplotlib axes, default=None | ||||||||||||||||||||||||||||
| Axes object to plot on. If `None`, a new figure and axes is created. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| **kwargs : dict | ||||||||||||||||||||||||||||
| Additional keywords arguments passed to matplotlib `plot` function. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||
| display : :class:`~sklearn.metrics.CumulativeAccuracyDisplay` | ||||||||||||||||||||||||||||
| Object that stores computed values. | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
| # validate and prepare data | ||||||||||||||||||||||||||||
| if pos_label is None: | ||||||||||||||||||||||||||||
| pos_label = 1 | ||||||||||||||||||||||||||||
| if sample_weight is None: | ||||||||||||||||||||||||||||
| sample_weight = np.ones_like(y_true, dtype=float) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # ensure y_true is boolean for positive class identification | ||||||||||||||||||||||||||||
| y_bool = y_true == pos_label | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # sort predictions and true values based on the predictions | ||||||||||||||||||||||||||||
| sorted_indices = np.argsort(y_pred)[::-1] | ||||||||||||||||||||||||||||
| y_true_sorted = y_bool[sorted_indices] | ||||||||||||||||||||||||||||
| sample_weight_sorted = sample_weight[sorted_indices] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # compute cumulative sums for true positives and all cases | ||||||||||||||||||||||||||||
| cumulative_true_positives = np.cumsum(y_true_sorted * sample_weight_sorted) | ||||||||||||||||||||||||||||
| cumulative_total = np.cumsum(sample_weight_sorted) | ||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is very similar to a Lorenz curve that we manually define in this example. scikit-learn/examples/linear_model/plot_poisson_regression_non_normal_loss.py Lines 495 to 507 in 6bf0ba5
The differences are (as far as I can see):
So I think we could indeed reuse the same class both for CAP/Lift/Gain curves and Lorenz curves. To switch between both, we can probably introduce a keyword parameter
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed in the main thread of the PR, let's keep that for a follow-up PR. |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| viz = cls( | ||||||||||||||||||||||||||||
| cumulative_true_positives=cumulative_true_positives, | ||||||||||||||||||||||||||||
| cumulative_total=cumulative_total, | ||||||||||||||||||||||||||||
| estimator_name=name, | ||||||||||||||||||||||||||||
| pos_label=pos_label, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return viz.plot( | ||||||||||||||||||||||||||||
| ax=ax, | ||||||||||||||||||||||||||||
| name=name, | ||||||||||||||||||||||||||||
| normalize_scale=normalize_scale, | ||||||||||||||||||||||||||||
| plot_chance_level=plot_chance_level, | ||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||
| def from_estimator( | ||||||||||||||||||||||||||||
| cls, | ||||||||||||||||||||||||||||
| estimator, | ||||||||||||||||||||||||||||
| X, | ||||||||||||||||||||||||||||
| y, | ||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||
| sample_weight=None, | ||||||||||||||||||||||||||||
| response_method="auto", | ||||||||||||||||||||||||||||
| pos_label=None, | ||||||||||||||||||||||||||||
| normalize_scale=False, | ||||||||||||||||||||||||||||
| plot_chance_level=False, | ||||||||||||||||||||||||||||
| name=None, | ||||||||||||||||||||||||||||
| ax=None, | ||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| """Create the Cumulative Accuracy Profile. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| This is also known as a Gain or Lift Curve for classification, and a Lorenz | ||||||||||||||||||||||||||||
| curve for regression with a positively valued target. | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+233
to
+234
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we agree to leave the Lorenz curve / regression case for a follow-up PR I would defer this change for later. Speaking about regression on an class that is designed and document to only accept classifier predictions would be a source confusion. |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||
| estimator : estimator instance | ||||||||||||||||||||||||||||
| Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` | ||||||||||||||||||||||||||||
| in which the last estimator is a classifier. | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||||||||||||||||||||||||||||
| Input values. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| y : array-like of shape (n_samples,) | ||||||||||||||||||||||||||||
| Target values. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| sample_weight : array-like of shape (n_samples,), default=None | ||||||||||||||||||||||||||||
| Sample weights. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| response_method : {'predict_proba', 'decision_function', 'auto'} \ | ||||||||||||||||||||||||||||
| default='auto' | ||||||||||||||||||||||||||||
| Specifies whether to use :term:`predict_proba` or | ||||||||||||||||||||||||||||
| :term:`decision_function` as the target response. If set to 'auto', | ||||||||||||||||||||||||||||
| :term:`predict_proba` is tried first and if it does not exist | ||||||||||||||||||||||||||||
| :term:`decision_function` is tried next. | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| pos_label : int, float, bool or str, default=None | ||||||||||||||||||||||||||||
| The class considered as the positive class when computing metrics. | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| By default, `estimators.classes_[1]` is considered as the positive class. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| name : str, default=None | ||||||||||||||||||||||||||||
| Name of CAP Curve for labeling. If `None`, use the name of the estimator. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ax : matplotlib axes, default=None | ||||||||||||||||||||||||||||
| Axes object to plot on. If `None`, a new figure and axes is created. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| **kwargs : dict | ||||||||||||||||||||||||||||
| Keyword arguments to be passed to matplotlib's `plot`. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||
| display : :class:`~sklearn.metrics.CumulativeAccuracyDisplay` | ||||||||||||||||||||||||||||
| The CAP Curve display. | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # validate and prepare the prediction scores | ||||||||||||||||||||||||||||
| if response_method == "auto": | ||||||||||||||||||||||||||||
| if hasattr(estimator, "predict_proba"): | ||||||||||||||||||||||||||||
| response_method = "predict_proba" | ||||||||||||||||||||||||||||
| elif hasattr(estimator, "decision_function"): | ||||||||||||||||||||||||||||
| response_method = "decision_function" | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||
| "Estimator does not have a predict_proba or decision_function" | ||||||||||||||||||||||||||||
| " method." | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if response_method == "predict_proba": | ||||||||||||||||||||||||||||
JosephBARBIERDARNAL marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
| probabilities = estimator.predict_proba(X) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # assuming positive class is the second column | ||||||||||||||||||||||||||||
| if pos_label is None: | ||||||||||||||||||||||||||||
| pos_label = 1 | ||||||||||||||||||||||||||||
| class_index = np.where(estimator.classes_ == pos_label)[0][0] | ||||||||||||||||||||||||||||
| y_pred = probabilities[:, class_index] | ||||||||||||||||||||||||||||
| elif response_method == "decision_function": | ||||||||||||||||||||||||||||
| y_pred = estimator.decision_function(X) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if name is None: | ||||||||||||||||||||||||||||
| name = estimator.__class__.__name__ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return cls.from_predictions( | ||||||||||||||||||||||||||||
| y_true=y, | ||||||||||||||||||||||||||||
| y_pred=y_pred, | ||||||||||||||||||||||||||||
| sample_weight=sample_weight, | ||||||||||||||||||||||||||||
| name=name, | ||||||||||||||||||||||||||||
| normalize_scale=normalize_scale, | ||||||||||||||||||||||||||||
| plot_chance_level=plot_chance_level, | ||||||||||||||||||||||||||||
| ax=ax, | ||||||||||||||||||||||||||||
| pos_label=pos_label, | ||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from sklearn.datasets import make_classification | ||
| from sklearn.linear_model import LogisticRegression | ||
| from sklearn.model_selection import train_test_split | ||
|
|
||
| from ..cap_curve import CumulativeAccuracyDisplay | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def binary_classification_dataset(): | ||
| X, y = make_classification( | ||
| n_samples=100, | ||
| n_features=2, | ||
| n_informative=2, | ||
| n_redundant=0, | ||
| n_repeated=0, | ||
| n_classes=2, | ||
| random_state=42, | ||
| ) | ||
| return train_test_split(X, y, test_size=0.2, random_state=42) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def logistic_regression_model(binary_classification_dataset): | ||
| X_train, _, y_train, _ = binary_classification_dataset | ||
| clf = LogisticRegression(max_iter=1000) | ||
| clf.fit(X_train, y_train) | ||
| return clf | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("normalize_scale", [True, False]) | ||
| @pytest.mark.parametrize("plot_chance_level", [True, False]) | ||
| def test_cumulative_accuracy_display_from_predictions( | ||
| binary_classification_dataset, normalize_scale, plot_chance_level | ||
| ): | ||
| _, X_test, _, y_test = binary_classification_dataset | ||
| y_scores = np.random.rand(len(y_test)) | ||
|
|
||
| cap_display = CumulativeAccuracyDisplay.from_predictions( | ||
| y_test, | ||
| y_scores, | ||
| normalize_scale=normalize_scale, | ||
| plot_chance_level=plot_chance_level, | ||
| name="Test Classifier", | ||
| ) | ||
|
|
||
| assert cap_display is not None | ||
| assert hasattr(cap_display, "line_"), "The display must have a line attribute" | ||
| assert hasattr(cap_display, "ax_"), "The display must have an ax attribute" | ||
| assert hasattr(cap_display, "figure_"), "The display must have a figure attribute" | ||
| if plot_chance_level: | ||
| assert ( | ||
| cap_display.chance_level_ is not None | ||
| ), "Chance level line should be present" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("normalize_scale", [True, False]) | ||
| @pytest.mark.parametrize("plot_chance_level", [True, False]) | ||
| def test_cumulative_accuracy_display_from_estimator( | ||
| logistic_regression_model, | ||
| binary_classification_dataset, | ||
| normalize_scale, | ||
| plot_chance_level, | ||
| ): | ||
| _, X_test, _, y_test = binary_classification_dataset | ||
|
|
||
| cap_display = CumulativeAccuracyDisplay.from_estimator( | ||
| logistic_regression_model, | ||
| X_test, | ||
| y_test, | ||
| normalize_scale=normalize_scale, | ||
| plot_chance_level=plot_chance_level, | ||
| name="Logistic Regression", | ||
| ) | ||
|
|
||
| assert cap_display is not None | ||
| assert hasattr(cap_display, "line_"), "The display must have a line attribute" | ||
| assert hasattr(cap_display, "ax_"), "The display must have an ax attribute" | ||
| assert hasattr(cap_display, "figure_"), "The display must have a figure attribute" | ||
| if plot_chance_level: | ||
| assert ( | ||
| cap_display.chance_level_ is not None | ||
| ), "Chance level line should be present" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also extend the tests to check for the presence shape and dtypes of other public attributes such as I might also make sense to check that the values of those attribute match (with
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've made the change for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've just realised that I've misread it. You meant checking the attributes of y_true_cumulative, not the display itself. |
||
Uh oh!
There was an error while loading. Please reload this page.