
FIX ignore single node trees in gbm's feature importances #13620

Merged: jnothman merged 8 commits into scikit-learn:master from adrinjalali:gbm/feature_importances on Apr 16, 2019

Conversation

@adrinjalali
Member

Fixes #7406

I'm not sure how to test it though. The feature importances of the original post's example don't change after this PR.

@jnothman
Member

So does this fix the issue?

@adrinjalali
Member Author

I believe so, but I'd be happier if we could actually test it. @NicolasHug do you happen to have an example for this case?

@NicolasHug
Member

NicolasHug commented Apr 12, 2019

I don't understand what the issue is here.

The original issue (FIs not summing to 1) has been fixed: see #7406 (comment).

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=10, random_state=1)
gbc = GradientBoostingClassifier(min_samples_leaf=5)

gbc.fit(X, y)
for stage in gbc.estimators_:
    for tree in stage:
        print(tree.tree_.node_count)
        print(tree.tree_.compute_feature_importances())

print('-' * 10)
print(gbc.feature_importances_)

This will give you an ensemble with 3-node trees and 1-node trees at the end. The FIs look good.

Details (abridged: 74 three-node trees followed by 26 single-node trees)
3
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
3
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
... (each remaining 3-node tree likewise puts all importance on feature 8 or feature 10) ...
1
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
... (each remaining 1-node tree reports all-zero importances) ...
----------
[0.         0.         0.         0.         0.         0.
 0.         0.         0.45126063 0.         0.54873937 0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]

@adrinjalali
Member Author

In your example, the FIs are the same before and after this PR because there are only two classes, and hence every stage has only one tree. If you change that, the FIs will differ. For example:

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=15, n_informative=3, random_state=1, n_classes=3)
gbc = GradientBoostingClassifier(min_samples_leaf=7, random_state=42, n_estimators=300)

gbc.fit(X, y)
for stage in gbc.estimators_:
    print("==========================")
    for tree in stage:
        print(tree.tree_.node_count)
        print(tree.tree_.compute_feature_importances())

print('-' * 10)
print(gbc.feature_importances_)

will report the following two sets of FIs before and after this PR. They're not very different in this case, but they're not the same:

[0.00000000e+00 2.96624006e-02 0.00000000e+00 0.00000000e+00
 8.03712487e-02 3.38731483e-02 0.00000000e+00 0.00000000e+00
 8.07677841e-02 4.80279915e-06 1.86077175e-07 2.90284866e-01
 0.00000000e+00 1.00705503e-01 0.00000000e+00 3.27791464e-02
 0.00000000e+00 2.43740224e-03 2.95737733e-01 5.33757789e-02]

[0.00000000e+00 2.96624317e-02 0.00000000e+00 0.00000000e+00
 8.03712358e-02 3.38731879e-02 0.00000000e+00 0.00000000e+00
 8.07677823e-02 4.82758884e-06 1.86077085e-07 2.90284856e-01
 0.00000000e+00 1.00705503e-01 0.00000000e+00 3.27791688e-02
 0.00000000e+00 2.43740107e-03 2.95737592e-01 5.33758276e-02]

        stage_sum = sum(tree.tree_.compute_feature_importances(
            normalize=False) for tree in trees) / len(trees)
        total_sum += stage_sum

    importances = total_sum / total_sum.sum()
Member

This will give a ZeroDivision error if all the trees in the ensemble have only one node. You can reproduce with something like:

X, y = make_classification(n_samples=10, random_state=1)
gbc = GradientBoostingClassifier(min_samples_leaf=10)
gbc.fit(X, y)
gbc.feature_importances_  # every tree is a single node, so total_sum.sum() is 0

Member Author

That's why I have the `if len(trees) == 0: continue` above it.

Member

Yes, and if you hit `continue` for every stage, then total_sum.sum() is 0 and you get an error.
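
A minimal illustration of that failure mode (a sketch; strictly speaking, dividing a NumPy array by 0 yields NaNs and a RuntimeWarning rather than a literal ZeroDivisionError):

import numpy as np

total_sum = np.zeros(20)                   # every stage hit `continue`
importances = total_sum / total_sum.sum()  # 0 / 0
print(importances)                         # array of 20 NaNs, plus a RuntimeWarning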

@adrinjalali
Member Author

@NicolasHug like this?

        if len(trees) == 0:
            continue
        stage_sum = sum(tree.tree_.compute_feature_importances(
            normalize=False) for tree in trees) / len(trees)
Member

I'm not sure this is correct: it will give more weight to the stages where one (or more) tree is "missing" (because it's been filtered out). Say you have 3 classes and one tree is filtered out at a given stage: stage_sum will then be an average over 2 trees instead of 3, so those 2 trees end up with more weight than the trees at the other stages.
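
To make that concrete, a small numeric sketch (hypothetical importances, not taken from the example above):

import numpy as np

# Unnormalized per-tree importances over 2 features; stage 1 had one of
# its 3 trees filtered out, stage 2 kept all 3.
stage1 = [np.array([1., 0.]), np.array([0., 1.])]
stage2 = [np.array([1., 0.]), np.array([0., 1.]), np.array([1., 0.])]

# Per-stage averaging: stage 1's trees weigh 1/2 each, stage 2's 1/3 each.
per_stage = (np.mean(stage1, axis=0) + np.mean(stage2, axis=0)) / 2
# Global averaging: all 5 surviving trees weigh 1/5 each.
global_avg = np.mean(stage1 + stage2, axis=0)

print(per_stage)   # [0.58333333 0.41666667]
print(global_avg)  # [0.6 0.4]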

After all, we just want to compute the average feature importances across all trees that have more than one node, right? I feel like this does what we want:

        relevant_trees = [tree
                          for stage in self.estimators_ for tree in stage
                          if tree.tree_.node_count > 1]
        if not relevant_trees:
            # degenerate case where all trees have only one node
            return np.zeros(shape=self.n_features_)

        relevant_feature_importances = [
            tree.tree_.compute_feature_importances(normalize=False)
            for tree in relevant_trees
        ]
        avg_feature_importances = np.mean(relevant_feature_importances, axis=0)
        return avg_feature_importances / np.sum(avg_feature_importances)

Member Author

I agree, changed it to your code

                          if tree.tree_.node_count > 1]
        if not relevant_trees:
            # degenerate case where all trees have only one node
            return np.zeros(shape=self.n_features_)
Member

We can test this by growing single-node trees:

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

X = np.zeros((10, 10))  # constant features, so no split is possible
y = np.ones((10,))
gbr = GradientBoostingRegressor()
gbr.fit(X, y)

Do we want this degenerate case to also sum to one?

Member Author

I don't think it makes sense to return non-zero values for the degenerate case. At least in my mind, the zero vector is a better representative of what we have. It's like, are features equally important, or are they not important at all? I think they're not important at all in this case.
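
Concretely, with the all-zeros behavior implemented here, a caller would see something like this (a sketch, assuming this PR's fix is in place):

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

X = np.zeros((10, 10))  # constant features: every tree stays a single root node
y = np.ones((10,))
gbr = GradientBoostingRegressor().fit(X, y)

importances = gbr.feature_importances_
assert importances.shape == (10,) and not importances.any()  # all zeros, sums to 0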

Member

If this does not sum to one, we should document it.

Member Author

added to the docstring

@NicolasHug
Member

If you want a non-regression test making sure we ignore the single-node trees, you could simply compute the FIs starting with

relevant_trees = [tree
                  for stage in self.estimators_ for tree in stage]

without filtering, and assert that the results differ.

LGTM anyway.


Some additional thoughts: I'm not an FI expert, but maybe in some cases one would want the intra-class FIs. In multiclass classification we build C trees per iteration, and for some datasets I wouldn't be surprised if the FIs differed quite a bit across classes.

Right now we're averaging over all the (non-single-node) trees without considering the classes. Other aggregating strategies would be to first compute intra-class FIs and only do a (weighted) average after. Or even return an array of shape (C, n_features). Anyway, just some thoughts, nothing that would prevent this PR from being merged ;)
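
For illustration, here is a rough sketch of the intra-class idea; per_class_feature_importances is a hypothetical helper, not part of this PR or of scikit-learn:

import numpy as np

def per_class_feature_importances(gbc):
    # gbc.estimators_ has shape (n_stages, K); column c holds the trees
    # fitted for class c at each boosting stage.
    n_classes = gbc.estimators_.shape[1]
    per_class = []
    for c in range(n_classes):
        # Same single-node filtering as in this PR, applied per class.
        trees = [stage[c] for stage in gbc.estimators_
                 if stage[c].tree_.node_count > 1]
        fi = np.mean([t.tree_.compute_feature_importances(normalize=False)
                      for t in trees], axis=0)
        per_class.append(fi / fi.sum())
    # Shape (K, n_features); ignores the degenerate case where a class
    # has no tree with more than one node.
    return np.array(per_class)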

@adrinjalali
Member Author

Right now we're averaging over all the (non-single-node) trees without considering the classes. Other aggregating strategies would be to first compute intra-class FIs and only do a (weighted) average after. Or even return an array of shape (C, n_features). Anyway, just some thoughts, nothing that would prevent this PR from being merged ;)

I really like that, but it's backward incompatible, and since feature_importances_ is a property and not a function, it's even trickier to have the deprecation cycle on it, or to add an input parameter. It's kinda related to #9606

@adrinjalali
Member Author

I just realized we have the same issue in forest.py and potentially in weight_boosting.py.

:class:`ensemble.GradientBoostingRegressor` now ignores the single
node trees in feature importance calculation, and returns an
array of zeros if all trees have only one single node (i.e. the root node).
:issue:`13620` by `Adrin Jalali`_.
Member Author

@NicolasHug I'm not sure how you'd like to merge the two changelogs here.

Member

From the other PR:

  - sum up to ``1``
  - all the single node trees in feature importance calculation are ignored
  - in case all trees have only one single node (i.e. a root node),
    feature importances will be an array of all zeros.

All of these cases seem to apply here?

Member Author

Not the first one, but merged them anyway.

@adrinjalali force-pushed the gbm/feature_importances branch from 733a2fb to 6720395 on April 15, 2019 15:04
@jnothman jnothman merged commit fd93ea0 into scikit-learn:master Apr 16, 2019
@jnothman
Member

Thanks @adrinjalali

@adrinjalali adrinjalali deleted the gbm/feature_importances branch April 16, 2019 07:10
jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Apr 25, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
Closes: Bug in Gradient Boosting: Feature Importances do not sum to 1 (#7406)