mean_squared_error ignores the squared argument if multioutput="raw_values" #16313
Describe the bug
The mean_squared_error metric ignores the squared argument if multioutput="raw_values".
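sklearn.metrics.mean_squared_error(y_true, y_pred, multioutput="raw_values", squared=False) returns the same result as sklearn.metrics.mean_squared_error(y_true, y_pred, multioutput="raw_values", squared=True), i.e. the per-output mean squared errors instead of their square roots.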
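To make the expected behaviour concrete, here is a NumPy-only sketch (not scikit-learn code) of what multioutput="raw_values" with squared=False should presumably return: the square root of each per-output mean squared error. The arrays are made-up illustration values chosen so the results are whole numbers.

```python
import numpy as np

# Two samples (rows), two outputs (columns); illustrative values only.
y_true = np.array([[1.0, 0.0], [3.0, 6.0]])
y_pred = np.array([[3.0, 3.0], [1.0, 3.0]])

# Per-output mean squared error: average the squared residuals over samples (axis=0).
mse_per_output = np.mean((y_true - y_pred) ** 2, axis=0)

# squared=False is expected to return the square root of each per-output error.
rmse_per_output = np.sqrt(mse_per_output)

print(mse_per_output)   # [4. 9.]
print(rmse_per_output)  # [2. 3.]
```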
Suggested fix:
On this line:
https://github.com/scikit-learn/scikit-learn/blob/b194674c4/sklearn/metrics/_regression.py#L258
replace
return output_errors
with
return output_errors if squared else np.sqrt(output_errors)
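A minimal sketch of the suggested fix in context, assuming the linked line is the early return in the multioutput="raw_values" branch. The function name mean_squared_error_patched is hypothetical, and scikit-learn's input validation is replaced by a simple reshape, so this is a simplified stand-in rather than the library's actual implementation.

```python
import numpy as np


def mean_squared_error_patched(y_true, y_pred, sample_weight=None,
                               multioutput="uniform_average", squared=True):
    """Hypothetical, simplified stand-in for sklearn.metrics.mean_squared_error.

    Input validation is omitted; 1-D inputs are treated as a single output.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    # Per-output MSE, (optionally weighted) average over samples.
    output_errors = np.average((y_true - y_pred) ** 2, axis=0,
                               weights=sample_weight)

    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            # The suggested fix: honour `squared` in this branch too.
            return output_errors if squared else np.sqrt(output_errors)
        elif multioutput == "uniform_average":
            # weights=None makes np.average compute a plain mean.
            multioutput = None
        else:
            raise ValueError(f"unknown multioutput value: {multioutput!r}")

    mse = np.average(output_errors, weights=multioutput)
    return mse if squared else np.sqrt(mse)


print(mean_squared_error_patched([[1]], [[10]], multioutput="raw_values", squared=True))   # [81.]
print(mean_squared_error_patched([[1]], [[10]], multioutput="raw_values", squared=False))  # [9.]
```

Note that in this sketch the uniform_average branch takes the square root after averaging the per-output MSEs, so its squared=False result is not, in general, the mean of the per-output values returned by raw_values with squared=False.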
Steps/Code to Reproduce
The bug can be reproduced with the following code:
import sklearn.metrics  # import the submodule explicitly so sklearn.metrics is available

print("sklearn version:", sklearn.__version__)
print(
    "raw values, squared == non-squared: ",
    sklearn.metrics.mean_squared_error(
        [[1]], [[10]], multioutput="raw_values", squared=True
    )
    == sklearn.metrics.mean_squared_error(
        [[1]], [[10]], multioutput="raw_values", squared=False
    ),
)
print(
    "uniform average, squared == non-squared:",
    sklearn.metrics.mean_squared_error(
        [[1]], [[10]], multioutput="uniform_average", squared=True
    )
    == sklearn.metrics.mean_squared_error(
        [[1]], [[10]], multioutput="uniform_average", squared=False
    ),
)
This returns the following output:
sklearn version: 0.22.1
raw values, squared == non-squared: [ True]
uniform average, squared == non-squared: False
Versions
System:
python: 3.7.4 (default, Aug 13 2019, 15:17:50) [Clang 4.0.1 (tags/RELEASE_401/final)]
executable: /Users/usr/.miniconda/envs/env/bin/python
machine: Darwin-19.3.0-x86_64-i386-64bit
Python dependencies:
pip: 19.2.3
setuptools: 41.4.0
sklearn: 0.22.1
numpy: 1.17.2
scipy: 1.3.1
Cython: None
pandas: 0.25.1
matplotlib: 3.1.1
joblib: 0.13.2
Built with OpenMP: True