The validation of the `ax` parameter to `plot_partial_dependence` is weirdly inconsistent:

- `plot_partial_dependence` only validates the length of `ax` if it's a list. But `plt.subplots` typically returns an array.
- The Display object checks `len(ax)`, while it should probably check `ax.size`. For example, if I pass an `ax` with shape (2, 2) and only plot 2 PDPs, the Display does not complain, since `len()` of a (2, 2) array is 2 and happens to match the number of features. However, if I ravel this `ax` to shape (4,), `len()` becomes 4 and the Display errors.
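A quick NumPy illustration of why only the raveled array trips the check: `len()` returns the length of the first axis, while `.size` counts every element. To avoid needing matplotlib, this sketch assumes an empty 2x2 object array is a fair stand-in for the Axes grid returned by `plt.subplots(nrows=2, ncols=2)`.

```python
import numpy as np

# Stand-in for the (2, 2) Axes grid from plt.subplots(nrows=2, ncols=2)
grid = np.empty((2, 2), dtype=object)
flat = np.ravel(grid)  # shape (4,)

# len() only counts the first dimension; .size counts every element
print(len(grid), grid.size)  # 2 4
print(len(flat), flat.size)  # 4 4
```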
I think we need a bit more consistency in the validation.
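One possible way to make the list and array code paths agree is to convert `ax` with `np.asarray` and compare its total size against the number of features, whatever its type or shape. This is only a sketch: it assumes the intended rule is "exactly one Axes per feature", `check_n_axes` is a hypothetical helper rather than a scikit-learn function, and placeholder objects stand in for real matplotlib Axes.

```python
import numpy as np


def check_n_axes(ax, n_features):
    """Hypothetical helper (not part of scikit-learn): validate lists and
    arrays of Axes the same way, by total element count."""
    # np.asarray turns a list of Axes into a 1-D object array, so lists and
    # ndarrays of any shape go through the same check
    axes = np.asarray(ax, dtype=object)
    if axes.size != n_features:
        raise ValueError(
            f"Expected ax to contain {n_features} axes, got {axes.size}"
        )
    # Return a flat array so callers can iterate over one Axes per feature
    return axes.ravel()


# Usage with placeholder objects instead of real Axes
grid = np.empty((2, 2), dtype=object)
for candidate in (grid, np.ravel(grid), list(np.ravel(grid))):
    try:
        check_n_axes(candidate, n_features=2)
    except ValueError as exc:
        print(exc)  # same error for the 2-D array, the raveled array and the list
```

With this rule, the (2, 2) grid, its raveled (4,) form and an equivalent list are all rejected the same way, while a list or array of exactly two Axes passes.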
CC @thomasjpfan
```python
# Needed in scikit-learn versions where HistGradientBoosting* is experimental
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.datasets import make_regression
from sklearn.inspection import plot_partial_dependence
import matplotlib.pyplot as plt
import numpy as np
# %matplotlib inline  # uncomment when running in Jupyter

X, y = make_regression()
est = HistGradientBoostingRegressor()
est.fit(X, y)

# plt.subplots returns a (2, 2) ndarray of Axes, not a list
fig, axes = plt.subplots(nrows=2, ncols=2, squeeze=False)
# Uncomment this to get an error: the raveled array has shape (4,)
# axes = np.ravel(axes)
plot_partial_dependence(est, X, features=[1, 2], ax=axes)
```