-
Notifications
You must be signed in to change notification settings - Fork 217
Expand file tree
/
Copy pathkmeans.py
More file actions
103 lines (90 loc) · 3.02 KB
/
kmeans.py
File metadata and controls
103 lines (90 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Purpose of this query module is to offer easy kmeans clustering algorithm on top of the embeddings that you
might have stored in nodes. All you need to do is call kmeans.get_clusters(5, "embedding") where 5
represents number of clusters you want to get, and "embedding" represents node property name in which
embedding of node is stored
"""
from typing import List, Tuple
import mgp
from sklearn.cluster import KMeans
def _as_int(value, name):
    """Return ``value`` as an int when it is a whole-valued float; otherwise return it unchanged."""
    if isinstance(value, float):
        if not value.is_integer():
            raise ValueError(f"{name} must be a whole number, got {value}")
        return int(value)
    # Non-floats are passed through so sklearn's own validation still applies.
    return value


def get_created_clusters(
    number_of_clusters: int,
    embeddings: List[List[float]],
    nodes: List[mgp.Vertex],
    init: str,
    n_init: int,
    max_iter: int,
    tol: float,
    algorithm: str,
    random_state: int,
) -> List[Tuple[mgp.Vertex, int]]:
    """Run k-means over ``embeddings`` and pair each node with its cluster label.

    ``embeddings[i]`` is expected to be the embedding of ``nodes[i]``; both
    lists must have the same length and order.

    Args:
        number_of_clusters: Number of clusters to form (KMeans ``n_clusters``).
        embeddings: One embedding vector per node.
        nodes: Vertices whose embeddings are being clustered.
        init, n_init, max_iter, tol, algorithm, random_state: Forwarded to
            ``sklearn.cluster.KMeans``. Whole-valued float counts are coerced
            to ``int``.

    Returns:
        ``(node, cluster_label)`` tuples in input order, with labels as plain
        Python ints. An empty list when there is nothing to cluster.

    Raises:
        ValueError: If a count parameter is a non-whole float.
    """
    # Nothing to cluster (e.g. an empty graph): KMeans.fit rejects an empty
    # input with an opaque shape error, so report no assignments instead.
    if len(embeddings) == 0:
        return []

    # The procedures declare these as mgp.Number, which may arrive as floats
    # (e.g. 5.0); sklearn validates them as integers.
    kmeans = KMeans(
        n_clusters=_as_int(number_of_clusters, "number_of_clusters"),
        init=init,
        n_init=_as_int(n_init, "n_init"),
        max_iter=_as_int(max_iter, "max_iter"),
        tol=tol,
        algorithm=algorithm,
        random_state=random_state,
    ).fit(embeddings)

    # labels_ is a NumPy integer array; convert so the result matches the
    # annotated List[Tuple[mgp.Vertex, int]] return type.
    return [(node, int(label)) for node, label in zip(nodes, kmeans.labels_)]
def extract_nodes_embeddings(ctx: mgp.ProcCtx, embedding_property: str) -> Tuple[List[mgp.Vertex], List[List[float]]]:
    """Collect every vertex in the graph together with its embedding.

    Args:
        ctx: Procedure context whose ``graph.vertices`` are scanned.
        embedding_property: Name of the node property holding the embedding.

    Returns:
        ``(nodes, embeddings)`` where ``embeddings[i]`` is the value of
        ``embedding_property`` on ``nodes[i]``.

    Raises:
        ValueError: If a vertex has no ``embedding_property`` value.
    """
    nodes = []
    embeddings = []
    for node in ctx.graph.vertices:
        embedding = node.properties.get(embedding_property)
        # A missing property would otherwise reach KMeans as None and fail
        # with an opaque array-conversion error; report the real cause.
        if embedding is None:
            raise ValueError(
                f"Node {node.id} has no '{embedding_property}' property; "
                "every node must store an embedding to be clustered."
            )
        nodes.append(node)
        embeddings.append(embedding)
    return nodes, embeddings
@mgp.read_proc
def get_clusters(
    ctx: mgp.ProcCtx,
    n_clusters: mgp.Number,
    embedding_property: str = "embedding",
    init: str = "k-means++",
    n_init: mgp.Number = 10,
    max_iter: mgp.Number = 10,
    tol: mgp.Number = 1e-4,
    algorithm: str = "lloyd",
    random_state: int = 1998,
) -> mgp.Record(node=mgp.Vertex, cluster_id=mgp.Number):
    """Cluster node embeddings with k-means and return each node's cluster.

    Read-only procedure: nothing is written back to the graph.

    Args:
        n_clusters: Number of clusters to form.
        embedding_property: Node property that stores each node's embedding.
        init, n_init, max_iter, tol, algorithm, random_state: KMeans settings
            forwarded to ``get_created_clusters``.

    Returns:
        One record per node with fields ``node`` and ``cluster_id`` (an int).
    """
    graph_nodes, vectors = extract_nodes_embeddings(ctx, embedding_property)

    kmeans_settings = dict(
        init=init,
        n_init=n_init,
        max_iter=max_iter,
        tol=tol,
        algorithm=algorithm,
        random_state=random_state,
    )
    assignments = get_created_clusters(
        number_of_clusters=n_clusters,
        embeddings=vectors,
        nodes=graph_nodes,
        **kmeans_settings,
    )

    return [mgp.Record(node=vertex, cluster_id=int(cluster)) for vertex, cluster in assignments]
@mgp.write_proc
def set_clusters(
    ctx: mgp.ProcCtx,
    n_clusters: mgp.Number,
    embedding_property: str = "embedding",
    cluster_property="cluster_id",
    init: str = "k-means++",
    n_init: mgp.Number = 10,
    max_iter: mgp.Number = 10,
    tol: mgp.Number = 1e-4,
    algorithm: str = "lloyd",
    random_state=1998,
) -> mgp.Record(node=mgp.Vertex, cluster_id=mgp.Number):
    """Cluster node embeddings with k-means and store each node's cluster on the node.

    Args:
        n_clusters: Number of clusters to form.
        embedding_property: Node property that stores each node's embedding.
        cluster_property: Node property that receives the integer cluster label.
        init, n_init, max_iter, tol, algorithm, random_state: KMeans settings
            forwarded to ``get_created_clusters``.

    Returns:
        One record per node with fields ``node`` and ``cluster_id`` (an int),
        matching the value written to ``cluster_property``.
    """
    graph_nodes, vectors = extract_nodes_embeddings(ctx, embedding_property)

    kmeans_settings = dict(
        init=init,
        n_init=n_init,
        max_iter=max_iter,
        tol=tol,
        algorithm=algorithm,
        random_state=random_state,
    )
    assignments = get_created_clusters(
        number_of_clusters=n_clusters,
        embeddings=vectors,
        nodes=graph_nodes,
        **kmeans_settings,
    )

    results = []
    for vertex, cluster in assignments:
        cluster_id = int(cluster)
        # Write the label onto the vertex, then report the same value back.
        vertex.properties.set(cluster_property, cluster_id)
        results.append(mgp.Record(node=vertex, cluster_id=cluster_id))
    return results