
Replace balanced_accuracy with macro-averaged recall from sklearn #108

@rhiever

From conversations with @amueller, we discovered that "balanced accuracy" (as we've called it) is also known as "macro-averaged recall" as implemented in sklearn. As such, we don't need our own custom implementation of balanced_accuracy in TPOT. Let's refactor TPOT to replace balanced_accuracy with recall_score.
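Written out, the metric our `balanced_accuracy` loop computes is the unweighted mean of per-class recall, which is what sklearn averages when `average='macro'` is set. Here $C$ is the set of classes in the true labels (the set the loop iterates over), $TP_c$ is the number of samples of class $c$ that were predicted as $c$, and $TP_c + FN_c$ is the number of samples whose true class is $c$:

$$
\text{balanced accuracy} = \frac{1}{|C|}\sum_{c \in C} \frac{TP_c}{TP_c + FN_c} = \frac{1}{|C|}\sum_{c \in C} \text{recall}_c
$$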

The correct call is:

```python
recall_score(y_test, predictions, average='macro')
```

where `y_test` plays the role of our `class` column and `predictions` the role of our `guess` column.
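A small hand-checkable example shows why the macro average is the "balanced" one: every class contributes equally, whereas plain accuracy is dominated by the majority class. The toy labels below are made up purely for illustration.

```python
from sklearn.metrics import accuracy_score, recall_score

# Imbalanced toy labels (illustrative only): three samples of class 0, one of class 1
y_true = [0, 0, 0, 1]
y_pred = [0, 0, 1, 1]

# Per-class recall: class 0 -> 2/3 (two of three found), class 1 -> 1/1
per_class = recall_score(y_true, y_pred, average=None)

# Macro average weights both classes equally: (2/3 + 1) / 2 = 5/6
macro = recall_score(y_true, y_pred, average='macro')

# Plain accuracy counts samples, not classes: 3 correct out of 4 = 3/4
plain = accuracy_score(y_true, y_pred)

print('Per-class recall:', per_class)
print('Macro-averaged recall:', macro)
print('Plain accuracy:', plain)
```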

Here's some code that compares the two metrics on the digits dataset and confirms that they give the same result:

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split lives in model_selection
from sklearn.model_selection import train_test_split
from sklearn.metrics import recall_score
import numpy as np
import pandas as pd

# Load the 10-class digits dataset (n_class must be passed by keyword in recent sklearn)
digits = load_digits(n_class=10)
features, labels = digits['data'], digits['target']

X_train, X_test, y_train, y_test = train_test_split(features, labels, train_size=0.75, test_size=0.25)

clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
clf.fit(X_train, y_train)

def balanced_accuracy(result):
    """TPOT's custom metric: mean per-class accuracy over the classes in result['class']."""
    all_classes = list(set(result['class'].values))
    all_class_accuracies = []
    for this_class in all_classes:
        # Fraction of the samples of this class that were predicted as this class (its recall)
        this_class_accuracy = len(result[(result['guess'] == this_class) & (result['class'] == this_class)])\
            / float(len(result[result['class'] == this_class]))
        all_class_accuracies.append(this_class_accuracy)

    # Unweighted mean over classes, so each class counts equally regardless of its size
    return np.mean(all_class_accuracies)

predictions = clf.predict(X_test)

print('Macro-averaged recall:\t', recall_score(y_test, predictions, average='macro'))

data = pd.DataFrame({'class': y_test,
                     'guess': predictions})

print('Balanced accuracy:\t', balanced_accuracy(data))
```
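One detail worth checking before the refactor: by default, `recall_score` averages over every label that appears in either `y_true` or `y_pred`, while the loop above only iterates over classes in `result['class']`. The two can therefore diverge when a prediction contains a class that never occurs in the true labels. The sketch below is a hypothetical drop-in wrapper (the name `balanced_accuracy` and its DataFrame interface are kept from the code above, not taken from TPOT's actual refactor) that passes `labels` explicitly to reproduce the old behaviour; the toy data is made up to trigger the edge case.

```python
import warnings

import numpy as np
import pandas as pd
from sklearn.metrics import recall_score


def balanced_accuracy(result):
    """Hypothetical drop-in replacement for the custom metric.

    Expects a DataFrame with 'class' (true labels) and 'guess' (predictions).
    Restricting labels to the classes present in result['class'] matches
    the set of classes the original loop averaged over.
    """
    return recall_score(result['class'], result['guess'],
                        labels=np.unique(result['class']), average='macro')


# Edge case (illustrative data): class 2 is predicted but never occurs in the true labels
data = pd.DataFrame({'class': [0, 0, 1, 1],
                     'guess': [0, 2, 1, 1]})

with warnings.catch_warnings():
    # Default macro recall warns that recall is ill-defined for class 2 and scores it as 0
    warnings.simplefilter('ignore')
    plain_macro = recall_score(data['class'], data['guess'], average='macro')

restricted = balanced_accuracy(data)

print('Macro recall over all labels:', plain_macro)      # (1/2 + 1 + 0) / 3
print('Macro recall over true classes:', restricted)     # (1/2 + 1) / 2
```

On the digits comparison above this does not matter in practice, since every test class is present and the classifier can only predict classes it was trained on.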
