nearest_centroid.py is iterating over all Y labels instead of classes #4074
Closed
Description
In the fit method of nearest_centroid.py, the loop iterates over every label in Y rather than over the classes:
for cur_class in y_ind:
    center_mask = y_ind == cur_class
    nk[cur_class] = np.sum(center_mask)
    if is_X_sparse:
        center_mask = np.where(center_mask)[0]
But I think the idea is to compute the centroid for each possible class. Looping over "y_ind" visits every label in Y, one iteration per example, so the complexity becomes O(N*N*M), where N is the number of examples and M is the number of features. My understanding is that it should be O(N*M*K), where K is the number of unique classes. Thus the code should really be:
for cur_class in self.classes_:
    center_mask = y_ind == cur_class
    nk[cur_class] = np.sum(center_mask)
    if is_X_sparse:
        center_mask = np.where(center_mask)[0]
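To make the difference concrete, here is a minimal standalone sketch, not the actual scikit-learn code. It assumes y_ind holds integer-encoded labels in the range 0..K-1, so the per-class loop runs over range(n_classes). That is an assumption on my part: if self.classes_ holds the original (possibly non-integer) labels, comparing them against y_ind would not match, and looping over the encoded indices would be the safer variant. The function names and the iteration counter are hypothetical and exist only for illustration.

```python
import numpy as np


def counts_looping_over_labels(y_ind, n_classes):
    """Pattern reported in the issue: one iteration per example (N iterations)."""
    nk = np.zeros(n_classes, dtype=int)
    iterations = 0
    for cur_class in y_ind:
        center_mask = y_ind == cur_class      # O(N) comparison per iteration
        nk[cur_class] = np.sum(center_mask)   # same class is recomputed repeatedly
        iterations += 1
    return nk, iterations


def counts_looping_over_classes(y_ind, n_classes):
    """Proposed pattern: one iteration per class (K iterations)."""
    nk = np.zeros(n_classes, dtype=int)
    iterations = 0
    # Assumption: y_ind is integer-encoded, so class indices are 0..n_classes-1.
    for cur_class in range(n_classes):
        center_mask = y_ind == cur_class
        nk[cur_class] = np.sum(center_mask)
        iterations += 1
    return nk, iterations


# Six examples, three classes (made-up data for illustration).
y_ind = np.array([0, 1, 1, 2, 0, 1])
slow_nk, slow_iters = counts_looping_over_labels(y_ind, 3)
fast_nk, fast_iters = counts_looping_over_classes(y_ind, 3)
```

Both loops produce the same per-class counts; only the number of passes over the data differs (N versus K), which is where the extra factor in the complexity comes from.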