[MRG] Fix for float16 overflow on accumulator operations #13010
rth merged 6 commits into scikit-learn:master
Conversation
rth
left a comment
Thank you @baluyotraf !
It might be good to add a non-regression test for overflow in StandardScaler with float16.
sklearn/utils/__init__.py
Outdated
# Use at least float64 for the accumulating functions to avoid precision issues;
# see https://github.com/numpy/numpy/issues/9393
# The float64 is also retained as it is in case the float overflows
def safe_acc_op(op, x, *args, **kwargs):
Please make this private, maybe more verbose (_safe_accumulate_op), and move it to utils.extmath.
…tils.extmath. Also fixed some line lengths to fit the 80 limit (scikit-learn#13007)
Moved the function to extmath and added the test. I also verified that the test fails on master and that it passes in this branch. Thanks for the review. o/
jnothman
left a comment
Thanks!
Please add an entry to the change log at doc/whats_new/v0.21.rst. Like the other entries there, please reference this pull request with :issue: and credit yourself (and any other contributors, if applicable) with :user:.
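For illustration, an entry of the requested form might look like the snippet below. The wording and section placement are assumptions; only the :issue: number and :user: name come from this thread.

```rst
- |Fix| :class:`preprocessing.StandardScaler` no longer overflows when
  fitted on float16 data, by accumulating statistics in float64.
  :issue:`13010` by :user:`baluyotraf`.
```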
sklearn/utils/extmath.py
Outdated
    updated_variance = None
else:
    new_unnormalized_variance = np.nanvar(X, axis=0) * new_sample_count
    new_unnormalized_variance = \
We prefer line continuations to use parentheses rather than backslash where possible.
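For example (an illustrative snippet, not the PR's actual code):

```python
import numpy as np

X = np.arange(6.0).reshape(3, 2)
new_sample_count = 3

# Backslash continuation (works, but discouraged):
new_unnormalized_variance = \
    np.nanvar(X, axis=0) * new_sample_count

# Parenthesized continuation (preferred style):
new_unnormalized_variance = (
    np.nanvar(X, axis=0) * new_sample_count)
```

Parentheses survive editor reflows and trailing whitespace, which is why most style guides (including PEP 8) prefer them over backslashes.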
I think I saw a backslash somewhere, so I kind of went along with it. I'll take note of this.
# Overflow calculations may cause -inf, inf, or nan. Since there is no nan
# input, all of the outputs should be finite. This may be redundant since a
# FloatingPointError exception will be thrown on overflow above.
assert np.all(np.isfinite(X_scaled))
I think it makes more sense to check that the output is identical to when the input is high precision. We may also want to check that the scaled features preserve the input dtype (although surely we have another test for that).
I tested it out before and found that the output is off after 2 or 3 decimal places. Should we cast the input during fit and cast it back to float16? It's kind of similar to #12333, only this time the imprecision is in the results rather than the mean.
Wouldn't you expect it to be off after 2 or 3 decimal places with float16?
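That expectation follows from the format itself; a quick back-of-the-envelope check (illustrative, not part of the PR):

```python
import numpy as np

# float16 has a 10-bit mantissa: eps (the ulp at 1.0) is 2**-10.
eps = float(np.finfo(np.float16).eps)

# The spacing between adjacent representable values scales with magnitude,
# so in the range [4.0, 8.0) it is eps * 4 = 2**-8, roughly 0.0039: only
# about 2 decimal places of the result are reliable there.
ulp_at_4 = eps * 4.0
```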
Would a test like this be enough?
def test_scaler_float16_overflow():
    # Test that the scaler does not overflow on float16 numpy arrays
    rng = np.random.RandomState(0)
    # float16 has a maximum of 65504. In the worst case 5 * 200000 is
    # 1,000,000, which is enough to overflow the data type
    X = rng.uniform(5, 10, [200000, 1]).astype(np.float16)
    with np.errstate(over='raise'):
        scaler = StandardScaler().fit(X)
        X_scaled = scaler.transform(X)
    # Calculate the float64 equivalent to verify the result
    X_scaled_f64 = StandardScaler().fit_transform(X.astype(np.float64))
    # Overflow calculations may cause -inf, inf, or nan. Since there is no nan
    # input, all of the outputs should be finite. This may be redundant since a
    # FloatingPointError exception will be thrown on overflow above.
    assert np.all(np.isfinite(X_scaled))
    # The normal distribution is very unlikely to go above 4. At 4.0-8.0 the
    # float16 precision is 2^-8, which is around 0.004. Thus only 2 decimals
    # are checked to account for precision differences.
    assert_array_almost_equal(X_scaled, X_scaled_f64, decimal=2)
There are CI failures, btw.
Kind of you to show your working. Looks great (especially if it also passes)!
…ult with respect to their precisions (scikit-learn#13007)
This did not fix #5602?
…ler (scikit-learn#13010)" This reverts commit 2ff7649.
Reference Issues/PRs
This fixes #13007
What does this implement/fix? Explain your changes.
A dtype of float64 is passed when using numpy-based accumulator functions to prevent overflow. This is only done for floating point inputs.
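As a rough illustration of the mechanism (a standalone sketch, not the PR's code): passing dtype=np.float64 keeps the accumulator in high precision even though the input array stays float16.

```python
import numpy as np

# 200000 values of 5.0 sum to 1,000,000, far above float16's max of 65504
X = np.full((200000, 1), 5.0, dtype=np.float16)

# Accumulating in the input dtype can overflow to inf:
naive_sum = np.nansum(X, axis=0)

# With a float64 accumulator the sum is exact:
safe_sum = np.nansum(X, axis=0, dtype=np.float64)  # array([1000000.])
```

The same dtype argument is accepted by the other accumulators involved here (np.sum, np.nanvar, etc.), which is what makes a single wrapper practical.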