Skip to content

FIX use specific threshold to discard eigenvalues with 32 bits fp#18149

Merged
rth merged 3 commits intoscikit-learn:masterfrom
smarie:fix_issue_18146
Aug 13, 2020
Merged

FIX use specific threshold to discard eigenvalues with 32 bits fp#18149
rth merged 3 commits intoscikit-learn:masterfrom
smarie:fix_issue_18146

Conversation

@smarie
Copy link
Copy Markdown
Contributor

@smarie smarie commented Aug 13, 2020

Reference Issues/PRs

Fixes #18146

What does this implement/fix? Explain your changes.

Fixed 32/64bit consistency for KernelPCA and other models using _check_psd_eigenvalues. Small positive eigenvalues were not correctly discarded by _check_psd_eigenvalues for 32bit data.

Any other comments?

Note that all of these rules are just "rule of thumb" validated/confirmed by experience (although strongly influenced by the minimum value available in single and double precision) so it is very important to add this example (and other failing ones if we find any) in the test harness to ensure non-regression over time.

Reminder:

  • single-precision: np.finfo('float32').eps = 1.2e-07
  • double-precision: np.finfo('float64').eps = 2.2e-16

Another note: we could imagine to extend the proposed test to generate various toy dataset configurations, in order to be even more sure of the thresholds used.

Sylvain MARIE added 3 commits August 13, 2020 11:04
…heck_psd_eigenvalues`. Small positive eigenvalues were not correctly discarded by `_check_psd_eigenvalues` for 32bit data. Fixes scikit-learn#18146
@smarie smarie changed the title Fix issue 18146 [MRG] Fix issue 18146 Aug 13, 2020
@smarie
Copy link
Copy Markdown
Contributor Author

smarie commented Aug 13, 2020

@glemaitre > ready for review

@glemaitre
Copy link
Copy Markdown
Member

That's was super fast :). I will check this in the afternoon. Thanks @smarie

Copy link
Copy Markdown
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. ping @rth @adrinjalali @NicolasHug for a second quick review

@glemaitre glemaitre changed the title [MRG] Fix issue 18146 FIX use specific threshold to discard eigenvalues with 32 bits floating precision Aug 13, 2020
Copy link
Copy Markdown
Member

@rth rth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @smarie !

@rth rth changed the title FIX use specific threshold to discard eigenvalues with 32 bits floating precision FIX use specific threshold to discard eigenvalues with 32 bits fp Aug 13, 2020
@rth rth merged commit acf195c into scikit-learn:master Aug 13, 2020
@smarie
Copy link
Copy Markdown
Contributor Author

smarie commented Aug 13, 2020

You're welcome :)

@smarie smarie deleted the fix_issue_18146 branch August 13, 2020 12:28
@alfaro96
Copy link
Copy Markdown
Member

@smarie
Copy link
Copy Markdown
Contributor Author

smarie commented Aug 26, 2020

@alfaro96 , the issue seems only to happen on mac os x but only for python >= 3.7. This is therefore probably related to a very low-level numerical computation behaviour in the numpy or linear algebra libraries, that is different in this versions.
Could anyone with a mac try and

  • run this specific test several times( 10, 20): does the issue always happen or only from time to time ?
  • put a breakpoint in _check_psd_eigenvalues and copy below in this thread the arrays received by the function, in both cases (the test calls this function twice, once for 64 and once for 32 bits). For reference, the values on windows 7 are Something goes wrong with KernelPCA with 32 bits input data #18146 (comment)

From all of this we'll be able to determine if raising a little bit the 1e-7 threshold in https://github.com/scikit-learn/scikit-learn/pull/18149/files#diff-3b70045c110a30d29de66ed0ea3fb86dR1186 is the right thing to do, or if something more nasty happens depending on the platform (and in that case this is probably good news to know it)

@alfaro96
Copy link
Copy Markdown
Member

@alfaro96 , the issue seems only to happen on mac os x but only for python >= 3.7. This is therefore probably related to a very low-level numerical computation behaviour in the numpy or linear algebra libraries, that is different in this versions.
Could anyone with a mac try and

  • run this specific test several times( 10, 20): does the issue always happen or only from time to time ?

I could not reproduce this issue in my machine (macOS-10.15.6-x86_64-i386-64bit). Nevertheless, I am experiencing the same issue here: #17921.

So, I assume that this is "generalized" problem.

  • put a breakpoint in _check_psd_eigenvalues and copy below in this thread the arrays received by the function, in both cases (the test calls this function twice, once for 64 and once for 32 bits). For reference, the values on windows 7 are #18146 (comment)

These are the arrays for the macos-latest environment raising the error:

# 64 bit data
array([-2.75182923e-14 -2.82984727e-15 -2.32422419e-15 -2.28444794e-15,
       -1.88368307e-15 -1.03866288e-15 -9.64184248e-16 -8.61787743e-16,
       -8.00057562e-16 -5.87978712e-16 -3.30763427e-16 -2.54953209e-16,
       -1.82340883e-16 -1.56598272e-16  9.53318764e-18  8.86546109e-17,
        2.40790473e-16  2.74047675e-16  5.38711820e-16  8.08403061e-16,
        1.06110990e-15  1.32942923e-15  1.53738341e-15  2.06144632e-15,
        3.44642618e-15  3.58062038e-15  5.06066380e-14  9.09952230e-01,
        1.30832231e+00  8.77817255e+01])

# 32 bit data
array([-5.4791326e-06 -3.4567588e-06 -1.8017618e-06 -1.0867386e-06
       -4.6405938e-07 -3.6561633e-07 -3.5079950e-07 -1.8227604e-07
       -1.6541932e-07 -1.1700362e-07  1.6083459e-07  2.1471459e-07
        2.6229949e-07  2.7478683e-07  3.4371274e-07  3.7471702e-07
        5.2492760e-07  6.4274491e-07  6.8415738e-07  7.0834585e-07
        9.6557756e-07  1.1529023e-06  1.3964327e-06  1.4208520e-06
        1.5312862e-06  3.3968322e-06  8.9305595e-06  9.0995252e-01
        1.3083217e+00  8.7781647e+01])

From all of this we'll be able to determine if raising a little bit the 1e-7 threshold in https://github.com/scikit-learn/scikit-learn/pull/18149/files#diff-3b70045c110a30d29de66ed0ea3fb86dR1186 is the right thing to do, or if something more nasty happens depending on the platform (and in that case this is probably good news to know it)

@smarie
Copy link
Copy Markdown
Contributor Author

smarie commented Aug 26, 2020

ok so if I'm not mistaken, the ratio between 4th and 1st eigenvalue is 8.9305595e-06 / 8.7781647e+01 so it is slightly above 1e-07: 1.0173606676575571e-07

Setting the threshold to 5e-7 instead of 1e-7 in https://github.com/scikit-learn/scikit-learn/pull/18149/files#diff-3b70045c110a30d29de66ed0ea3fb86dR1186 should therefore solve the problem with enough margin. What do you think ?

(the new line should be small_pos_ratio = 1e-12 if is_double_precision else 5e-7)

@alfaro96
Copy link
Copy Markdown
Member

ok so if I'm not mistaken, the ratio between 4th and 1st eigenvalue is 8.9305595e-06 / 8.7781647e+01 so it is slightly above 1e-07: 1.0173606676575571e-07

You are right, the ratio is slightly above the small_pos_ratio threshold.

Setting the threshold to 5e-7 instead of 1e-7 in https://github.com/scikit-learn/scikit-learn/pull/18149/files#diff-3b70045c110a30d29de66ed0ea3fb86dR1186 should therefore solve the problem with enough margin. What do you think ?

Changing 1e-7 by 5e-7 should definitely solve the issue, but maybe we can use a smaller threshold. For instance, 2e-7 instead of 5e-7.

(the new line should be small_pos_ratio = 1e-12 if is_double_precision else 5e-7)

Do you mind to submit a PR to solve this issue?

Thank you @smarie!

smarie pushed a commit to smarie/scikit-learn that referenced this pull request Aug 27, 2020
@smarie
Copy link
Copy Markdown
Contributor Author

smarie commented Aug 27, 2020

Done in #18270. Let me know how this goes @alfaro96

jayzed82 pushed a commit to jayzed82/scikit-learn that referenced this pull request Oct 22, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Something goes wrong with KernelPCA with 32 bits input data

4 participants