
plot_roc_curve doesn't correctly infer pos_label #15303

@amueller

Description

from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_roc_curve

X, y = make_blobs(centers=2)
y = y.astype(str)  # string class labels "0" and "1"
lr = LogisticRegression().fit(X, y)
plot_roc_curve(lr, X, y)
# raises ValueError("Data is not binary and pos_label is not specified")

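I would argue that pos_label=lr.classes_[1] is the right choice here, since for a binary classifier the second column of predict_proba and the positive side of decision_function both correspond to classes_[1].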
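A minimal sketch of the proposed behavior, assuming we compute the curve directly with roc_curve instead of going through plot_roc_curve: take pos_label from the fitted estimator's classes_[1] and pair it with column 1 of predict_proba. The random_state value is an arbitrary choice added for reproducibility.

```python
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve

X, y = make_blobs(centers=2, random_state=0)
y = y.astype(str)  # string class labels "0" and "1"
lr = LogisticRegression().fit(X, y)

# For a binary classifier, column 1 of predict_proba holds the
# probability of lr.classes_[1], so that class is the natural positive label.
pos_label = lr.classes_[1]
scores = lr.predict_proba(X)[:, 1]

# Passing pos_label explicitly avoids the "pos_label is not specified" error
# that roc_curve raises for non-{0, 1} / non-{-1, 1} labels such as strings.
fpr, tpr, thresholds = roc_curve(y, scores, pos_label=pos_label)
print(pos_label, auc(fpr, tpr))
```

Because classes_ is sorted, this picks "1" here; the same rule applies to any pair of string labels without the caller having to know which one is "positive".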

cc @thomasjpfan
