1717import warnings
1818from contextlib import closing
1919from datetime import datetime
20- from typing import Any , Optional
20+ from typing import TYPE_CHECKING , Any , Callable , Iterable , List , Mapping , Optional , Tuple , Union
2121
22+ import sqlparse
2223from sqlalchemy import create_engine
2324from typing_extensions import Protocol
2425
2728from airflow .providers_manager import ProvidersManager
2829from airflow .utils .module_loading import import_string
2930
31+ if TYPE_CHECKING :
32+ from sqlalchemy .engine import CursorResult
33+
34+
35+ def fetch_all_handler (cursor : 'CursorResult' ) -> Optional [List [Tuple ]]:
36+ """Handler for DbApiHook.run() to return results"""
37+ if cursor .returns_rows :
38+ return cursor .fetchall ()
39+ else :
40+ return None
41+
3042
3143def _backported_get_hook (connection , * , hook_params = None ):
3244 """Return hook based on conn_type
@@ -201,7 +213,31 @@ def get_first(self, sql, parameters=None):
201213 cur .execute (sql )
202214 return cur .fetchone ()
203215
204- def run (self , sql , autocommit = False , parameters = None , handler = None ):
216+ @staticmethod
217+ def strip_sql_string (sql : str ) -> str :
218+ return sql .strip ().rstrip (';' )
219+
220+ @staticmethod
221+ def split_sql_string (sql : str ) -> List [str ]:
222+ """
223+ Splits string into multiple SQL expressions
224+
225+ :param sql: SQL string potentially consisting of multiple expressions
226+ :return: list of individual expressions
227+ """
228+ splits = sqlparse .split (sqlparse .format (sql , strip_comments = True ))
229+ statements = [s .rstrip (';' ) for s in splits if s .endswith (';' )]
230+ return statements
231+
232+ def run (
233+ self ,
234+ sql : Union [str , Iterable [str ]],
235+ autocommit : bool = False ,
236+ parameters : Optional [Union [Iterable , Mapping ]] = None ,
237+ handler : Optional [Callable ] = None ,
238+ split_statements : bool = False ,
239+ return_last : bool = True ,
240+ ) -> Optional [Union [Any , List [Any ]]]:
205241 """
206242 Runs a command or a list of commands. Pass a list of sql
207243 statements to the sql parameter to get them to execute
@@ -213,14 +249,19 @@ def run(self, sql, autocommit=False, parameters=None, handler=None):
213249 before executing the query.
214250 :param parameters: The parameters to render the SQL query with.
215251 :param handler: The result handler which is called with the result of each statement.
216- :return: query results if handler was provided.
252+ :param split_statements: Whether to split a single SQL string into statements and run separately
253+ :param return_last: Whether to return result for only last statement or for all after split
254+ :return: return only result of the ALL SQL expressions if handler was provided.
217255 """
218- scalar = isinstance (sql , str )
219- if scalar :
220- sql = [sql ]
256+ scalar_return_last = isinstance (sql , str ) and return_last
257+ if isinstance (sql , str ):
258+ if split_statements :
259+ sql = self .split_sql_string (sql )
260+ else :
261+ sql = [self .strip_sql_string (sql )]
221262
222263 if sql :
223- self .log .debug ("Executing %d statements" , len (sql ))
264+ self .log .debug ("Executing following statements against DB: %s " , list (sql ))
224265 else :
225266 raise ValueError ("List of SQL statements is empty" )
226267
@@ -232,22 +273,21 @@ def run(self, sql, autocommit=False, parameters=None, handler=None):
232273 results = []
233274 for sql_statement in sql :
234275 self ._run_command (cur , sql_statement , parameters )
276+
235277 if handler is not None :
236278 result = handler (cur )
237279 results .append (result )
238280
239- # If autocommit was set to False for db that supports autocommit,
240- # or if db does not supports autocommit, we do a manual commit.
281+ # If autocommit was set to False or db does not support autocommit, we do a manual commit.
241282 if not self .get_autocommit (conn ):
242283 conn .commit ()
243284
244285 if handler is None :
245286 return None
246-
247- if scalar :
248- return results [0 ]
249-
250- return results
287+ elif scalar_return_last :
288+ return results [- 1 ]
289+ else :
290+ return results
251291
252292 def _run_command (self , cur , sql_statement , parameters ):
253293 """Runs a statement using an already open cursor."""
0 commit comments