nearest_centroid.py is iterating over all Y labels instead of classes #4074

@oeddyo

In nearest_centroid.py, the fit function loops over all of the Y labels:

    for cur_class in y_ind:
        center_mask = y_ind == cur_class
        nk[cur_class] = np.sum(center_mask)
        if is_X_sparse:
            center_mask = np.where(center_mask)[0]

But I think the idea is to compute the centroid for each possible class. Looping over "y_ind" goes over every Y, one iteration per example, so the complexity becomes O(N*N*M), where N is the number of examples and M is the number of features. My understanding is that it should be O(N*M*K), where K is the number of unique classes. So the code should really be:

    for cur_class in self.classes_:
        center_mask = y_ind == cur_class
        nk[cur_class] = np.sum(center_mask)
        if is_X_sparse:
            center_mask = np.where(center_mask)[0]
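
To illustrate the difference, here is a minimal standalone sketch, not the actual scikit-learn code. The function names `centroids_per_label` and `centroids_per_class` are hypothetical, and it assumes `y_ind` holds integer-encoded labels in the range 0..K-1, so the per-class loop iterates over `range(n_classes)`. If `y_ind` held the raw label values instead, the loop would iterate over the unique values. Both versions should give the same centroids; only the number of loop iterations differs.

```python
import numpy as np


def centroids_per_label(X, y_ind, n_classes):
    """Loop over every entry of y_ind (the pattern reported in this issue)."""
    centroids = np.zeros((n_classes, X.shape[1]))
    iterations = 0
    for cur_class in y_ind:  # N iterations, one per example
        center_mask = y_ind == cur_class
        centroids[cur_class] = X[center_mask].mean(axis=0)
        iterations += 1
    return centroids, iterations


def centroids_per_class(X, y_ind, n_classes):
    """Loop once per class index (assumes labels are encoded as 0..K-1)."""
    centroids = np.zeros((n_classes, X.shape[1]))
    iterations = 0
    for cur_class in range(n_classes):  # K iterations, one per class
        center_mask = y_ind == cur_class
        centroids[cur_class] = X[center_mask].mean(axis=0)
        iterations += 1
    return centroids, iterations


# Synthetic data, chosen only for illustration: 1000 examples, 5 features, 3 classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y_ind = rng.integers(0, 3, size=1000)

slow, n_slow = centroids_per_label(X, y_ind, 3)
fast, n_fast = centroids_per_class(X, y_ind, 3)
```

Here the first version runs its loop body 1000 times and the second only 3 times, while each iteration still scans all N labels to build the mask.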
