DummyClassifier bug with putting arrays into lists #10786

@lrossert

Description

/site-packages/sklearn/dummy.py

In sklearn.dummy, DummyClassifier.predict fails when the attribute n_outputs_ is 1. The error is raised around line 223 of dummy.py:

               y = np.tile([classes_[k][class_prior_[k].argmax()] for
                             k in range(self.n_outputs_)], [n_samples, 1])

Earlier, around line 189, the attributes classes_ and class_prior_ are wrapped in lists:

               classes_ = [classes_]
               class_prior_ = [class_prior_]

I think this causes the error. argmax() is an array method and does not exist on Python lists, and because the array is put into a list around line 189, argmax() ends up being called on a list rather than on an array.
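The failure mode can be shown with NumPy alone. This is a minimal sketch, not the scikit-learn code: the names `class_prior`, `wrapped_once` and `wrapped_twice` and the prior values are made up for illustration, and it assumes the attribute is already a list of arrays by the time it gets wrapped again.

```python
import numpy as np

# Hypothetical class prior for a two-class problem (values are made up).
class_prior = np.array([0.37, 0.63])

# Wrapping once gives a list whose element is still an ndarray.
wrapped_once = [class_prior]
print(wrapped_once[0].argmax())

# Wrapping an existing list again gives a list whose element is a list,
# and plain Python lists have no argmax() method.
wrapped_twice = [wrapped_once]
try:
    wrapped_twice[0].argmax()
except AttributeError as exc:
    print(exc)
```

The second call raises the same AttributeError as in the traceback reported under Actual Results.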

A hotfix I've used locally is to change line 223 so that it takes the first element of the list for classes_ and class_prior_:

               y = np.tile([classes_[0][k][class_prior_[0][k].argmax()] for
                             k in range(self.n_outputs_)], [n_samples, 1])
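To see what the hotfix expression evaluates to, here is a standalone sketch with stand-in values. The doubly wrapped `classes_` and `class_prior_` below are assumptions chosen to mimic the state described above, and `n_samples` and `n_outputs` are hypothetical; none of this is taken from the scikit-learn source.

```python
import numpy as np

n_samples = 3   # hypothetical number of rows to predict
n_outputs = 1   # single-output case, as in the report

# Stand-in values: an array inside a list inside another list.
classes_ = [[np.array([0, 1])]]
class_prior_ = [[np.array([0.37, 0.63])]]

# Same shape as the hotfix line: index [0] first to undo the extra wrapping.
y = np.tile([classes_[0][k][class_prior_[0][k].argmax()]
             for k in range(n_outputs)], [n_samples, 1])

print(y.shape)
print(y)
```

With these values every row of `y` holds the most frequent class, which matches the array layout described under Expected Results.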

This error occurs when the DummyClassifier strategy is set to most_frequent:

DummyClassifier(strategy='most_frequent')

It may also occur with some of the other DummyClassifier strategy options.

Steps/Code to Reproduce

import pandas as pd
from sklearn import datasets
from math import floor

from sklearn.dummy import DummyClassifier

############## Get the test train split #############
# Load the breast_cancer dataset
breast_cancer = datasets.load_breast_cancer()

# data
feature_names = pd.Series(breast_cancer.feature_names)
breast_cancer_data = breast_cancer.data
breast_cancer_data = pd.DataFrame(breast_cancer_data)
breast_cancer_data = breast_cancer_data.rename(columns=feature_names)

# response
breast_cancer_response = breast_cancer.target
breast_cancer_response = pd.DataFrame(breast_cancer_response)
breast_cancer_response = breast_cancer_response.rename(columns={0:'response'})

# since the data is randomly shuffled we take the last 30% of rows as our test
# set
# note that for more rigorous testing we could use SKLearn's train_test_split
# function. But we are not focussing on cleaning here

# we split into 70% train, 30% test
# (train_length below is actually the number of held-out test rows)
length = len(breast_cancer_data)
train_length = floor(length*0.3)

# Split the data into training/testing sets
breast_cancer_data_train = breast_cancer_data[:-train_length ]
breast_cancer_data_test = breast_cancer_data[-train_length :]

# Split the targets into training/testing sets
breast_cancer_response_train = breast_cancer_response[:-train_length ]
breast_cancer_response_test = breast_cancer_response[-train_length :]

##################### DummyClassifier Issue ################

dummy = DummyClassifier(strategy='most_frequent')
dummy.fit(X=breast_cancer_data_train, y=breast_cancer_response_train)

# error occurs on the following line
dummy_pred = dummy.predict(breast_cancer_data_test)

#dummy.n_outputs_
#[dummy.classes_][0][0][[dummy.class_prior_][0][0].argmax()]
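Note that the reproduction passes the target as a one-column DataFrame, so the target is two-dimensional with a single column. The NumPy-only sketch below illustrates that shape. The label values are made up, and it assumes that the number of outputs corresponds to the number of target columns, which would make this the n_outputs_ == 1 case from the description.

```python
import numpy as np

# Made-up labels standing in for breast_cancer.target.
target = np.array([0, 1, 1, 0, 1])

# A one-column DataFrame holds its values in this (n_samples, 1) layout.
as_column = target.reshape(-1, 1)

print(target.ndim, target.shape)        # 1-D target
print(as_column.ndim, as_column.shape)  # 2-D target with one column

# Assumption: one output per target column.
n_outputs = as_column.shape[1]
print(n_outputs)
```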

Expected Results

Expected an array of predictions. In this case, an array made up of only [1]s:

array([[1],
       .
       .
       .
       [1]])

Actual Results

  File "C:/Users/lancelot.rossert/Documents/Keyrus/Blog posts/dummy_issue.py", line 53, in <module>
    dummy_pred = dummy.predict(breast_cancer_data_test)

  File "C:\Users\lancelot.rossert\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\dummy.py", line 224, in predict
    k in range(self.n_outputs_)], [n_samples, 1])

  File "C:\Users\lancelot.rossert\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\dummy.py", line 224, in <listcomp>
    k in range(self.n_outputs_)], [n_samples, 1])

AttributeError: 'list' object has no attribute 'argmax'

Versions

Windows-7-6.1.7601-SP1
Python 3.6.2 |Anaconda custom (64-bit)| (default, Sep 19 2017, 08:03:39) [MSC v.1900 64 bit (AMD64)]
NumPy 1.13.3
SciPy 1.0.0
Scikit-Learn 0.19.0

Labels

Bug, Easy (well-defined and straightforward way to resolve)