
check_decision_proba_consistency fails with LinearDiscriminantAnalysis #19224

@ogrisel

Description


The following common check recently started to fail on a Windows CI job:

name = 'LinearDiscriminantAnalysis'
estimator_orig = LinearDiscriminantAnalysis()

    @ignore_warnings(category=FutureWarning)
    def check_decision_proba_consistency(name, estimator_orig):
        # Check whether an estimator having both decision_function and
        # predict_proba methods has outputs with perfect rank correlation.
    
        centers = [(2, 2), (4, 4)]
        X, y = make_blobs(n_samples=100, random_state=0, n_features=4,
                          centers=centers, cluster_std=1.0, shuffle=True)
        X_test = np.random.randn(20, 2) + 4
        estimator = clone(estimator_orig)
    
        if (hasattr(estimator, "decision_function") and
                hasattr(estimator, "predict_proba")):
    
            estimator.fit(X, y)
            # Since the link function from decision_function() to predict_proba()
            # is sometimes not precise enough (typically expit), we round to the
            # 10th decimal to avoid numerical issues.
            a = estimator.predict_proba(X_test)[:, 1].round(decimals=10)
            b = estimator.decision_function(X_test).round(decimals=10)
>           assert_array_equal(rankdata(a), rankdata(b))
E           AssertionError: 
E           Arrays are not equal
E           
E           Mismatched elements: 2 / 20 (10%)
E           Max absolute difference: 0.5
E           Max relative difference: 0.02631579
E            x: array([ 7. ,  8. , 11. ,  9. , 17. , 10. ,  5. , 14. ,  6. ,  1. , 19.5,
E                   4. ,  2. , 16. , 12. , 13. ,  3. , 15. , 19.5, 18. ])
E            y: array([ 7.,  8., 11.,  9., 17., 10.,  5., 14.,  6.,  1., 20.,  4.,  2.,
E                  16., 12., 13.,  3., 15., 19., 18.])
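The `19.5, 19.5` vs `20., 19.` mismatch above is the signature of a rounding-induced tie: two distinct decision values deep in the tail of the link function map to probabilities that become identical once rounded to 10 decimals, so `rankdata` assigns tied (averaged) ranks in one array but not the other. A minimal illustration with synthetic values (not the actual CI data):

```python
import numpy as np
from scipy.stats import rankdata

# Two distinct decision values far in the tail of the sigmoid link.
decision = np.array([0.5, 30.0, 31.0])
# expit saturates: both tail values are within ~1e-13 of 1.0.
proba = 1.0 / (1.0 + np.exp(-decision))

a = proba.round(decimals=10)     # rounding collapses the tail to exactly 1.0
b = decision.round(decimals=10)  # decision values remain distinct

print(rankdata(a))  # [1.  2.5 2.5]  -- tied ranks, averaged
print(rankdata(b))  # [1. 2. 3.]     -- no ties
```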

This happened on this PR, which should not have any impact on the behavior of LinearDiscriminantAnalysis.predict_proba: #17743.

https://dev.azure.com/scikit-learn/scikit-learn/_build/results?buildId=25474&view=logs&j=d32b16b6-cb9d-571b-e765-de83708fb8dd&t=b93f76c1-c2c9-579e-c2ec-c4f438af1261

I suspect the test is too brittle. Maybe using a test set drawn from the original distribution (blobs) would avoid ties caused by arbitrary rounding?
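A rough sketch of that idea, assuming the point is to keep `X_test` inside the blob distribution so decision values stay in a range where the link function remains injective after rounding (the `random_state=1` for the test set is an arbitrary choice, not from the original check):

```python
import numpy as np
from scipy.stats import rankdata
from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

centers = [(2, 2), (4, 4)]
X, y = make_blobs(n_samples=100, random_state=0, centers=centers,
                  cluster_std=1.0, shuffle=True)
# Draw the test set from the same blobs instead of np.random.randn(20, 2) + 4,
# so decision values stay moderate and expit does not saturate.
X_test, _ = make_blobs(n_samples=20, random_state=1, centers=centers,
                       cluster_std=1.0, shuffle=True)

est = LinearDiscriminantAnalysis().fit(X, y)
a = est.predict_proba(X_test)[:, 1].round(decimals=10)
b = est.decision_function(X_test).round(decimals=10)
print((rankdata(a) == rankdata(b)).all())
```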
