 from datetime import datetime
 from typing import (
     Any,
+    AsyncGenerator,
     Callable,
     Dict,
     Generator,
     Optional,
     Sequence,
     Tuple,
+    Union,
 )

 import pytz
-from psycopg import sql
+from psycopg import AsyncConnection, sql
 from psycopg.connection import Connection
-from psycopg_pool import ConnectionPool
+from psycopg_pool import AsyncConnectionPool, ConnectionPool

 from feast import Entity
 from feast.feature_view import FeatureView
 from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
 from feast.infra.online_stores.online_store import OnlineStore
-from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
+from feast.infra.utils.postgres.connection_utils import (
+    _get_conn,
+    _get_conn_async,
+    _get_connection_pool,
+    _get_connection_pool_async,
+)
 from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -51,6 +58,9 @@ class PostgreSQLOnlineStore(OnlineStore):
     _conn: Optional[Connection] = None
     _conn_pool: Optional[ConnectionPool] = None

+    _conn_async: Optional[AsyncConnection] = None
+    _conn_pool_async: Optional[AsyncConnectionPool] = None
+
     @contextlib.contextmanager
     def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
         assert config.online_store.type == "postgres"
@@ -67,6 +77,24 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
                 self._conn = _get_conn(config.online_store)
             yield self._conn

+    @contextlib.asynccontextmanager
+    async def _get_conn_async(
+        self, config: RepoConfig
+    ) -> AsyncGenerator[AsyncConnection, Any]:
+        if config.online_store.conn_type == ConnectionType.pool:
+            if not self._conn_pool_async:
+                self._conn_pool_async = await _get_connection_pool_async(
+                    config.online_store
+                )
+                await self._conn_pool_async.open()
+            connection = await self._conn_pool_async.getconn()
+            yield connection
+            await self._conn_pool_async.putconn(connection)
+        else:
+            if not self._conn_async:
+                self._conn_async = await _get_conn_async(config.online_store)
+            yield self._conn_async
+
     def online_write_batch(
         self,
         config: RepoConfig,
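
For context, the async helpers imported above (`_get_conn_async`, `_get_connection_pool_async`) live in `feast.infra.utils.postgres.connection_utils` and are not shown in this diff. Below is a minimal sketch of what they might look like, assuming a plain DSN string rather than the `PostgreSQLConfig` object the real helpers receive, and using only psycopg's documented async API.

from psycopg import AsyncConnection
from psycopg_pool import AsyncConnectionPool


async def _get_conn_async(conninfo: str) -> AsyncConnection:
    # Hypothetical sketch: a single autocommit connection, mirroring the
    # synchronous _get_conn helper.
    return await AsyncConnection.connect(conninfo=conninfo, autocommit=True)


async def _get_connection_pool_async(conninfo: str) -> AsyncConnectionPool:
    # Hypothetical sketch: the pool is created closed (open=False); the store
    # opens it explicitly, matching `await self._conn_pool_async.open()` above.
    return AsyncConnectionPool(conninfo=conninfo, open=False)
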
@@ -132,69 +160,107 @@ def online_read(
         entity_keys: List[EntityKeyProto],
         requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
-        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
+        keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
+        query, params = self._construct_query_and_params(
+            config, table, keys, requested_features
+        )

-        project = config.project
         with self._get_conn(config) as conn, conn.cursor() as cur:
-            # Collecting all the keys to a list allows us to make fewer round trips
-            # to PostgreSQL
-            keys = []
-            for entity_key in entity_keys:
-                keys.append(
-                    serialize_entity_key(
-                        entity_key,
-                        entity_key_serialization_version=config.entity_key_serialization_version,
-                    )
-                )
+            cur.execute(query, params)
+            rows = cur.fetchall()

-            if not requested_features:
-                cur.execute(
-                    sql.SQL(
-                        """
-                        SELECT entity_key, feature_name, value, event_ts
-                        FROM {} WHERE entity_key = ANY(%s);
-                        """
-                    ).format(
-                        sql.Identifier(_table_id(project, table)),
-                    ),
-                    (keys,),
-                )
-            else:
-                cur.execute(
-                    sql.SQL(
-                        """
-                        SELECT entity_key, feature_name, value, event_ts
-                        FROM {} WHERE entity_key = ANY(%s) and feature_name = ANY(%s);
-                        """
-                    ).format(
-                        sql.Identifier(_table_id(project, table)),
-                    ),
-                    (keys, requested_features),
-                )
+        return self._process_rows(keys, rows)

-            rows = cur.fetchall()
+    async def online_read_async(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        entity_keys: List[EntityKeyProto],
+        requested_features: Optional[List[str]] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+        keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
+        query, params = self._construct_query_and_params(
+            config, table, keys, requested_features
+        )

-            # Since we don't know the order returned from PostgreSQL we'll need
-            # to construct a dict to be able to quickly look up the correct row
-            # when we iterate through the keys since they are in the correct order
-            values_dict = defaultdict(list)
-            for row in rows if rows is not None else []:
-                values_dict[
-                    row[0] if isinstance(row[0], bytes) else row[0].tobytes()
-                ].append(row[1:])
-
-            for key in keys:
-                if key in values_dict:
-                    value = values_dict[key]
-                    res = {}
-                    for feature_name, value_bin, event_ts in value:
-                        val = ValueProto()
-                        val.ParseFromString(bytes(value_bin))
-                        res[feature_name] = val
-                    result.append((event_ts, res))
-                else:
-                    result.append((None, None))
+        async with self._get_conn_async(config) as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(query, params)
+                rows = await cur.fetchall()
+
+        return self._process_rows(keys, rows)
+
+    @staticmethod
+    def _construct_query_and_params(
+        config: RepoConfig,
+        table: FeatureView,
+        keys: List[bytes],
+        requested_features: Optional[List[str]] = None,
+    ) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]:
+        """Construct the SQL query based on the given parameters."""
+        if requested_features:
+            query = sql.SQL(
+                """
+                SELECT entity_key, feature_name, value, event_ts
+                FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s);
+                """
+            ).format(
+                sql.Identifier(_table_id(config.project, table)),
+            )
+            params = (keys, requested_features)
+        else:
+            query = sql.SQL(
+                """
+                SELECT entity_key, feature_name, value, event_ts
+                FROM {} WHERE entity_key = ANY(%s);
+                """
+            ).format(
+                sql.Identifier(_table_id(config.project, table)),
+            )
+            params = (keys, [])
+        return query, params
+
+    @staticmethod
+    def _prepare_keys(
+        entity_keys: List[EntityKeyProto], entity_key_serialization_version: int
+    ) -> List[bytes]:
+        """Prepare all keys in a list to make fewer round trips to the database."""
+        return [
+            serialize_entity_key(
+                entity_key,
+                entity_key_serialization_version=entity_key_serialization_version,
+            )
+            for entity_key in entity_keys
+        ]
+
+    @staticmethod
+    def _process_rows(
+        keys: List[bytes], rows: List[Tuple]
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
240+ """Transform the retrieved rows in the desired output.
197241
242+ PostgreSQL may return rows in an unpredictable order. Therefore, `values_dict`
243+ is created to quickly look up the correct row using the keys, since these are
244+ actually in the correct order.
245+ """
+        values_dict = defaultdict(list)
+        for row in rows if rows is not None else []:
+            values_dict[
+                row[0] if isinstance(row[0], bytes) else row[0].tobytes()
+            ].append(row[1:])
+
+        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
+        for key in keys:
+            if key in values_dict:
+                value = values_dict[key]
+                res = {}
+                for feature_name, value_bin, event_ts in value:
+                    val = ValueProto()
+                    val.ParseFromString(bytes(value_bin))
+                    res[feature_name] = val
+                result.append((event_ts, res))
+            else:
+                result.append((None, None))
         return result

     def update(
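
A hypothetical usage sketch of the new async read path follows. `store`, `repo_config`, `feature_view`, and `entity_key_protos` are placeholders for objects the caller already has, `PostgreSQLOnlineStore` refers to the class defined in this file, and the feature name is illustrative only.

from typing import List

from feast.feature_view import FeatureView
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.repo_config import RepoConfig


async def read_features(
    store: "PostgreSQLOnlineStore",
    repo_config: RepoConfig,
    feature_view: FeatureView,
    entity_key_protos: List[EntityKeyProto],
):
    # Delegates to online_read_async; the store lazily creates its async
    # connection (or async pool) on first use.
    return await store.online_read_async(
        config=repo_config,
        table=feature_view,
        entity_keys=entity_key_protos,
        requested_features=["conv_rate"],  # illustrative feature name
    )


# Run with, e.g.: asyncio.run(read_features(store, repo_config, feature_view, keys))
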