Skip to content

Commit df00436

Browse files
authored
Unify DbApiHook.run() method with the methods which override it (#23971)
1 parent 31705ed commit df00436

File tree

33 files changed

+307
-264
lines changed

33 files changed

+307
-264
lines changed

airflow/operators/sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def __init__(
496496
follow_task_ids_if_false: List[str],
497497
conn_id: str = "default_conn_id",
498498
database: Optional[str] = None,
499-
parameters: Optional[Union[Mapping, Iterable]] = None,
499+
parameters: Optional[Union[Iterable, Mapping]] = None,
500500
**kwargs,
501501
) -> None:
502502
super().__init__(conn_id=conn_id, database=database, **kwargs)

airflow/providers/amazon/aws/operators/redshift_sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
18+
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
1919

2020
from airflow.models import BaseOperator
2121
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
@@ -55,7 +55,7 @@ def __init__(
5555
*,
5656
sql: Union[str, Iterable[str]],
5757
redshift_conn_id: str = 'redshift_default',
58-
parameters: Optional[dict] = None,
58+
parameters: Optional[Union[Iterable, Mapping]] = None,
5959
autocommit: bool = True,
6060
**kwargs,
6161
) -> None:

airflow/providers/amazon/aws/transfers/redshift_to_s3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(
9393
unload_options: Optional[List] = None,
9494
autocommit: bool = False,
9595
include_header: bool = False,
96-
parameters: Optional[Union[Mapping, Iterable]] = None,
96+
parameters: Optional[Union[Iterable, Mapping]] = None,
9797
table_as_file_name: bool = True, # Set to True by default for not breaking current workflows
9898
**kwargs,
9999
) -> None:

airflow/providers/amazon/aws/transfers/s3_to_redshift.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
import warnings
19-
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
19+
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union
2020

2121
from airflow.exceptions import AirflowException
2222
from airflow.models import BaseOperator
@@ -140,7 +140,7 @@ def execute(self, context: 'Context') -> None:
140140

141141
copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)
142142

143-
sql: Union[list, str]
143+
sql: Union[str, Iterable[str]]
144144

145145
if self.method == 'REPLACE':
146146
sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]

airflow/providers/apache/drill/operators/drill.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
# under the License.
1818
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
1919

20-
import sqlparse
21-
2220
from airflow.models import BaseOperator
2321
from airflow.providers.apache.drill.hooks.drill import DrillHook
2422

@@ -52,7 +50,7 @@ def __init__(
5250
*,
5351
sql: str,
5452
drill_conn_id: str = 'drill_default',
55-
parameters: Optional[Union[Mapping, Iterable]] = None,
53+
parameters: Optional[Union[Iterable, Mapping]] = None,
5654
**kwargs,
5755
) -> None:
5856
super().__init__(**kwargs)
@@ -64,6 +62,4 @@ def __init__(
6462
def execute(self, context: 'Context'):
6563
self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id)
6664
self.hook = DrillHook(drill_conn_id=self.drill_conn_id)
67-
sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True))
68-
no_term_sql = [s[:-1] for s in sql if s[-1] == ';']
69-
self.hook.run(no_term_sql, parameters=self.parameters)
65+
self.hook.run(self.sql, parameters=self.parameters, split_statements=True)

airflow/providers/apache/drill/provider.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ dependencies:
3434
- apache-airflow>=2.2.0
3535
- apache-airflow-providers-common-sql
3636
- sqlalchemy-drill>=1.1.0
37-
- sqlparse>=0.4.1
3837

3938
integrations:
4039
- integration-name: Apache Drill

airflow/providers/apache/pinot/hooks/pinot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import os
2020
import subprocess
21-
from typing import Any, Dict, Iterable, List, Optional, Union
21+
from typing import Any, Iterable, List, Mapping, Optional, Union
2222

2323
from pinotdb import connect
2424

@@ -275,7 +275,7 @@ def get_uri(self) -> str:
275275
endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
276276
return f'{conn_type}://{host}/{endpoint}'
277277

278-
def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
278+
def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
279279
"""
280280
Executes the sql and returns a set of records.
281281
@@ -287,7 +287,7 @@ def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Itera
287287
cur.execute(sql)
288288
return cur.fetchall()
289289

290-
def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
290+
def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
291291
"""
292292
Executes the sql and returns the first resulting row.
293293

airflow/providers/common/sql/hooks/sql.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
import warnings
1818
from contextlib import closing
1919
from datetime import datetime
20-
from typing import Any, Optional
20+
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union
2121

22+
import sqlparse
2223
from sqlalchemy import create_engine
2324
from typing_extensions import Protocol
2425

@@ -27,6 +28,17 @@
2728
from airflow.providers_manager import ProvidersManager
2829
from airflow.utils.module_loading import import_string
2930

31+
if TYPE_CHECKING:
32+
from sqlalchemy.engine import CursorResult
33+
34+
35+
def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]:
36+
"""Handler for DbApiHook.run() to return results"""
37+
if cursor.returns_rows:
38+
return cursor.fetchall()
39+
else:
40+
return None
41+
3042

3143
def _backported_get_hook(connection, *, hook_params=None):
3244
"""Return hook based on conn_type
@@ -201,7 +213,31 @@ def get_first(self, sql, parameters=None):
201213
cur.execute(sql)
202214
return cur.fetchone()
203215

204-
def run(self, sql, autocommit=False, parameters=None, handler=None):
216+
@staticmethod
217+
def strip_sql_string(sql: str) -> str:
218+
return sql.strip().rstrip(';')
219+
220+
@staticmethod
221+
def split_sql_string(sql: str) -> List[str]:
222+
"""
223+
Splits string into multiple SQL expressions
224+
225+
:param sql: SQL string potentially consisting of multiple expressions
226+
:return: list of individual expressions
227+
"""
228+
splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
229+
statements = [s.rstrip(';') for s in splits if s.endswith(';')]
230+
return statements
231+
232+
def run(
233+
self,
234+
sql: Union[str, Iterable[str]],
235+
autocommit: bool = False,
236+
parameters: Optional[Union[Iterable, Mapping]] = None,
237+
handler: Optional[Callable] = None,
238+
split_statements: bool = False,
239+
return_last: bool = True,
240+
) -> Optional[Union[Any, List[Any]]]:
205241
"""
206242
Runs a command or a list of commands. Pass a list of sql
207243
statements to the sql parameter to get them to execute
@@ -213,14 +249,19 @@ def run(self, sql, autocommit=False, parameters=None, handler=None):
213249
before executing the query.
214250
:param parameters: The parameters to render the SQL query with.
215251
:param handler: The result handler which is called with the result of each statement.
216-
:return: query results if handler was provided.
252+
:param split_statements: Whether to split a single SQL string into statements and run separately
253+
:param return_last: Whether to return result for only last statement or for all after split
254+
:return: results of all SQL expressions if a handler was provided.
217255
"""
218-
scalar = isinstance(sql, str)
219-
if scalar:
220-
sql = [sql]
256+
scalar_return_last = isinstance(sql, str) and return_last
257+
if isinstance(sql, str):
258+
if split_statements:
259+
sql = self.split_sql_string(sql)
260+
else:
261+
sql = [self.strip_sql_string(sql)]
221262

222263
if sql:
223-
self.log.debug("Executing %d statements", len(sql))
264+
self.log.debug("Executing following statements against DB: %s", list(sql))
224265
else:
225266
raise ValueError("List of SQL statements is empty")
226267

@@ -232,22 +273,21 @@ def run(self, sql, autocommit=False, parameters=None, handler=None):
232273
results = []
233274
for sql_statement in sql:
234275
self._run_command(cur, sql_statement, parameters)
276+
235277
if handler is not None:
236278
result = handler(cur)
237279
results.append(result)
238280

239-
# If autocommit was set to False for db that supports autocommit,
240-
# or if db does not supports autocommit, we do a manual commit.
281+
# If autocommit was set to False or db does not support autocommit, we do a manual commit.
241282
if not self.get_autocommit(conn):
242283
conn.commit()
243284

244285
if handler is None:
245286
return None
246-
247-
if scalar:
248-
return results[0]
249-
250-
return results
287+
elif scalar_return_last:
288+
return results[-1]
289+
else:
290+
return results
251291

252292
def _run_command(self, cur, sql_statement, parameters):
253293
"""Runs a statement using an already open cursor."""

airflow/providers/common/sql/provider.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ description: |
2424
versions:
2525
- 1.0.0
2626

27-
dependencies: []
27+
dependencies:
28+
- sqlparse>=0.4.2
2829

2930
additional-extras:
3031
- name: pandas

airflow/providers/databricks/hooks/databricks_sql.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import re
1918
from contextlib import closing
2019
from copy import copy
21-
from typing import Any, Dict, List, Optional, Tuple, Union
20+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
2221

2322
from databricks import sql # type: ignore[attr-defined]
2423
from databricks.sql.client import Connection # type: ignore[attr-defined]
@@ -139,19 +138,15 @@ def get_conn(self) -> Connection:
139138
)
140139
return self._sql_conn
141140

142-
@staticmethod
143-
def maybe_split_sql_string(sql: str) -> List[str]:
144-
"""
145-
Splits strings consisting of multiple SQL expressions into an
146-
TODO: do we need something more sophisticated?
147-
148-
:param sql: SQL string potentially consisting of multiple expressions
149-
:return: list of individual expressions
150-
"""
151-
splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""]
152-
return splits
153-
154-
def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None):
141+
def run(
142+
self,
143+
sql: Union[str, Iterable[str]],
144+
autocommit: bool = False,
145+
parameters: Optional[Union[Iterable, Mapping]] = None,
146+
handler: Optional[Callable] = None,
147+
split_statements: bool = True,
148+
return_last: bool = True,
149+
) -> Optional[Union[Tuple[str, Any], List[Tuple[str, Any]]]]:
155150
"""
156151
Runs a command or a list of commands. Pass a list of sql
157152
statements to the sql parameter to get them to execute
@@ -163,41 +158,44 @@ def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, hand
163158
before executing the query.
164159
:param parameters: The parameters to render the SQL query with.
165160
:param handler: The result handler which is called with the result of each statement.
166-
:return: query results.
161+
:param split_statements: Whether to split a single SQL string into statements and run separately
162+
:param return_last: Whether to return result for only last statement or for all after split
163+
:return: result of only the last SQL expression if a handler was provided.
167164
"""
165+
scalar_return_last = isinstance(sql, str) and return_last
168166
if isinstance(sql, str):
169-
sql = self.maybe_split_sql_string(sql)
167+
if split_statements:
168+
sql = self.split_sql_string(sql)
169+
else:
170+
sql = [self.strip_sql_string(sql)]
170171

171172
if sql:
172-
self.log.debug("Executing %d statements", len(sql))
173+
self.log.debug("Executing following statements against Databricks DB: %s", list(sql))
173174
else:
174175
raise ValueError("List of SQL statements is empty")
175176

176-
conn = None
177+
results = []
177178
for sql_statement in sql:
178179
# when using AAD tokens, it could expire if previous query run longer than token lifetime
179-
conn = self.get_conn()
180-
with closing(conn.cursor()) as cur:
181-
self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters)
182-
if parameters:
183-
cur.execute(sql_statement, parameters)
184-
else:
185-
cur.execute(sql_statement)
186-
schema = cur.description
187-
results = []
188-
if handler is not None:
189-
cur = handler(cur)
190-
for row in cur:
191-
self.log.debug("Statement results: %s", row)
192-
results.append(row)
193-
194-
self.log.info("Rows affected: %s", cur.rowcount)
195-
if conn:
196-
conn.close()
180+
with closing(self.get_conn()) as conn:
181+
self.set_autocommit(conn, autocommit)
182+
183+
with closing(conn.cursor()) as cur:
184+
self._run_command(cur, sql_statement, parameters)
185+
186+
if handler is not None:
187+
result = handler(cur)
188+
schema = cur.description
189+
results.append((schema, result))
190+
197191
self._sql_conn = None
198192

199-
# Return only result of the last SQL expression
200-
return schema, results
193+
if handler is None:
194+
return None
195+
elif scalar_return_last:
196+
return results[-1]
197+
else:
198+
return results
201199

202200
def test_connection(self):
203201
"""Test the Databricks SQL connection by running a simple query."""

0 commit comments

Comments (0)