Skip to content

Commit dcf8743

Browse files
authored
[AIRFLOW-6894] Prevent db query in example_dags (#7516)
1 parent a812f48 commit dcf8743

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

airflow/providers/google/cloud/operators/cloud_sql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ def __init__(self,
808808
self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
809809
self.autocommit = autocommit
810810
self.parameters = parameters
811-
self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
811+
self.gcp_connection = None
812812

813813
def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: Union[PostgresHook, MySqlHook]):
814814
cloud_sql_proxy_runner = None
@@ -827,6 +827,7 @@ def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: Union[Postgr
827827
cloud_sql_proxy_runner.stop_proxy()
828828

829829
def execute(self, context):
830+
self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
830831
hook = CloudSQLDatabaseHook(
831832
gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
832833
gcp_conn_id=self.gcp_conn_id,

tests/test_example_dags.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,16 @@
2020
from glob import glob
2121

2222
from airflow.models import DagBag
23+
from tests.test_utils.asserts import assert_queries_count
2324

2425
ROOT_FOLDER = os.path.realpath(
2526
os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)
2627
)
2728

29+
NO_DB_QUERY_EXCEPTION = [
30+
"/airflow/example_dags/example_subdag_operator.py"
31+
]
32+
2833

2934
class TestExampleDags(unittest.TestCase):
3035
def test_should_be_importable(self):
@@ -38,3 +43,19 @@ def test_should_be_importable(self):
3843
)
3944
self.assertEqual(0, len(dagbag.import_errors), f"import_errors={str(dagbag.import_errors)}")
4045
self.assertGreaterEqual(len(dagbag.dag_ids), 1)
46+
47+
def test_should_not_do_database_queries(self):
48+
example_dags = glob(f"{ROOT_FOLDER}/airflow/**/example_dags/example_*.py", recursive=True)
49+
example_dags = [
50+
dag_file
51+
for dag_file in example_dags
52+
if any(not dag_file.endswith(e) for e in NO_DB_QUERY_EXCEPTION)
53+
]
54+
for filepath in example_dags:
55+
relative_filepath = os.path.relpath(filepath, ROOT_FOLDER)
56+
with self.subTest(f"File {relative_filepath} shouldn't do database queries"):
57+
with assert_queries_count(0):
58+
DagBag(
59+
dag_folder=filepath,
60+
include_examples=False,
61+
)

tests/test_utils/asserts.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,53 @@
1616
# under the License.
1717

1818
import re
19+
from contextlib import contextmanager
20+
21+
from sqlalchemy import event
22+
23+
from airflow.settings import engine
1924

2025

2126
def assert_equal_ignore_multiple_spaces(case, first, second, msg=None):
2227
def _trim(s):
2328
return re.sub(r"\s+", " ", s.strip())
2429
return case.assertEqual(_trim(first), _trim(second), msg)
30+
31+
32+
class CountQueriesResult:
33+
def __init__(self):
34+
self.count = 0
35+
36+
37+
class CountQueries:
38+
"""
39+
Counts the number of queries sent to Airflow Database in a given context.
40+
41+
Does not support multiple processes. When a new process is started in context, its queries will
42+
not be included.
43+
"""
44+
def __init__(self):
45+
self.result = CountQueriesResult()
46+
47+
def __enter__(self):
48+
event.listen(engine, "after_cursor_execute", self.after_cursor_execute)
49+
return self.result
50+
51+
def __exit__(self, type_, value, traceback):
52+
event.remove(engine, "after_cursor_execute", self.after_cursor_execute)
53+
54+
def after_cursor_execute(self, *args, **kwargs):
55+
self.result.count += 1
56+
57+
58+
count_queries = CountQueries # pylint: disable=invalid-name
59+
60+
61+
@contextmanager
62+
def assert_queries_count(expected_count, message_fmt=None):
63+
with count_queries() as result:
64+
yield None
65+
message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \
66+
"The current number is {current_count}."
67+
message = message_fmt.format(current_count=result.count, expected_count=expected_count)
68+
assert expected_count == result.count, message

0 commit comments

Comments
 (0)