[MRG + 1] FIX use high precision cumsum and check it is stable enough#7331
[MRG + 1] FIX use high precision cumsum and check it is stable enough#7331lesteve merged 6 commits intoscikit-learn:masterfrom
Conversation
|
|
sklearn/utils/extmath.py
Outdated
| """ | ||
| out = np.cumsum(arr, dtype=np.float64) | ||
| expected = np.sum(arr, dtype=np.float64) | ||
| if not np.allclose(out, expected): |
There was a problem hiding this comment.
if not np.allclose(out[-1], expected):There was a problem hiding this comment.
I thought I committed that change :p
|
Sorry for opening a PR and running away :) |
|
LGTM |
|
Indeed that could be problematic. LGTM |
|
LGTM |
|
Any ideas of pathological cases to test with quickly? |
|
No idea, except comparing |
|
I only think it's an effective test for very large vectors, hence On 6 September 2016 at 20:05, Tom Dupré la Tour notifications@github.com
|
|
how long does a test with a large-enough vector run? |
|
lgtm |
import sys
import time
import numpy as np
import pandas as pd
n_trials = 50
all_results = []
for dtype in [np.float32, np.float64]:
for i in range(3, 8):
n = 1 * 10 ** i
absdiff, reldiff = [], []
s = time.time()
for j in range(n_trials):
x = np.random.rand(n).astype(dtype)
a = np.cumsum(x)[-1]
b = np.sum(x)
absdiff.append(np.abs(a - b))
reldiff.append(absdiff[-1] / b)
all_results.append((dtype.__name__, n, (time.time() - s) / len(results), np.log10(np.mean(absdiff)), np.log10(np.mean(reldiff))))
pd.DataFrame(all_results, columns=['dtype', 'n', 'time', 'log abs diff', 'log rel diff'])
Two questions:
So I'll try build a non-regression test with 1e6 samples. |
|
A regression test for |
|
I've added a direct test of |
|
I suppose the remaining question is: are there other places in the codebase where cumsum is happening and might be over float32 data or very very large arrays? |
|
One of the Travis error will go away if you rebase on master but there is one on Python 2.7 that seems genuine: |
Yeah, I was wondering if there'd be some platform that didn't fail my test.... |
|
I assume that means numpy used to have an unstable implementation of sum, rather than it used to have a stable implementation of cumsum |
|
I wonder: Should I only run the test on recent numpy, or just remove it? |
|
Hm... testing the error message seems slightly odd since it relies on numpy being broken. I guess having a test with a known correct result that wasn't correct before might be nicer, but not offer the error coverage? Like |
|
It might just be my bedtime, but I'm not sure I get what test you're suggesting. |
|
|
|
No, I don't think it will work for |
|
ah... huh... |
|
So just drop the test, I suppose. |
|
And merge the PR? |
|
fine with me |
|
I believe the improved stability of np.sum was done in numpy 1.9: http://docs.scipy.org/doc/numpy/release.html#better-numerical-stability-for-sum-in-some-cases and I also quickly checked that in 1.8 cumsum and sum were giving the same wrong result in one of your snippet but 1.9 was fine. Maybe skip the test for numpy < 1.9? |
sklearn/utils/tests/test_extmath.py
Outdated
|
|
||
|
|
||
| def test_stable_cumsum(): | ||
| if np_version < (1, 19): |
There was a problem hiding this comment.
19 -> 9 otherwise we may wait a while until this test gets run ;-).
There was a problem hiding this comment.
Argh. Making so many errors. Should be in bed. Should take a break from this stuff, too!
|
could please run a check for np1.9, @lesteve? |
Not sure what you exactly mean, but here is what I tried: import numpy as np
np.random.seed(42)
n_samples = 4 * 10 ** 7
y = np.random.randint(2, size=n_samples)
prediction = np.random.normal(size=n_samples) + y * 0.01
trivial_weight = np.ones(n_samples)
print(np.cumsum(trivial_weight.astype('float32'))[-1])
print(np.sum(trivial_weight.astype('float32')))Output for numpy 1.8: Output for numpy 1.9: |
|
I also made sure that the test was run for numpy 1.9. LGTM will wait for AppVeyor and then merge. |
|
Merging, thanks! |
…scikit-learn#7331) * FIX use high precision cumsum and check it is stable enough
…scikit-learn#7331) * FIX use high precision cumsum and check it is stable enough
…scikit-learn#7331) * FIX use high precision cumsum and check it is stable enough
…scikit-learn#7331) * FIX use high precision cumsum and check it is stable enough
Fixes #6842. I don't know a test that will run quickly enough.