Skip to content

Commit 7ef75d2

Browse files
authored
[AIRFLOW-7117] Honor self.schema in sql_to_gcs as schema to upload (#8049)
1 parent cc9b1bc commit 7ef75d2

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

airflow/providers/google/cloud/operators/sql_to_gcs.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,18 +251,28 @@ def _get_col_type_dict(self):
251251
def _write_local_schema_file(self, cursor):
252252
"""
253253
Takes a cursor, and writes the BigQuery schema for the results to a
254-
local file system.
254+
local file system. Schema for database will be read from cursor if
255+
not specified.
255256
256257
:return: A dictionary where key is a filename to be used as an object
257258
name in GCS, and values are file handles to local files that
258259
contains the BigQuery schema fields in .json format.
259260
"""
260-
schema = [self.field_to_bigquery(field) for field in cursor.description]
261+
if self.schema:
262+
self.log.info("Using user schema")
263+
schema = self.schema
264+
else:
265+
self.log.info("Starts generating schema")
266+
schema = [self.field_to_bigquery(field) for field in cursor.description]
267+
268+
if isinstance(schema, list):
269+
schema = json.dumps(schema, sort_keys=True)
261270

262271
self.log.info('Using schema for %s', self.schema_filename)
263272
self.log.debug("Current schema: %s", schema)
273+
264274
tmp_schema_file_handle = NamedTemporaryFile(delete=True)
265-
tmp_schema_file_handle.write(json.dumps(schema, sort_keys=True).encode('utf-8'))
275+
tmp_schema_file_handle.write(schema.encode('utf-8'))
266276
schema_file_to_upload = {
267277
'file_name': self.schema_filename,
268278
'file_handle': tmp_schema_file_handle,

tests/providers/google/cloud/operators/test_mysql_to_gcs.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
BUCKET = 'gs://test'
3434
JSON_FILENAME = 'test_{}.ndjson'
3535
CSV_FILENAME = 'test_{}.csv'
36+
SCHEMA = [
37+
{'mode': 'REQUIRED', 'name': 'some_str', 'type': 'FLOAT'},
38+
{'mode': 'REQUIRED', 'name': 'some_num', 'type': 'TIMESTAMP'}
39+
]
3640

3741
ROWS = [
3842
('mock_row_content_1', 42),
@@ -65,6 +69,10 @@
6569
b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ',
6670
b'{"mode": "REQUIRED", "name": "some_num", "type": "STRING"}]'
6771
]
72+
CUSTOM_SCHEMA_JSON = [
73+
b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ',
74+
b'{"mode": "REQUIRED", "name": "some_num", "type": "TIMESTAMP"}]'
75+
]
6876

6977

7078
class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase):
@@ -293,6 +301,36 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
293301
# once for the file and once for the schema
294302
self.assertEqual(2, gcs_hook_mock.upload.call_count)
295303

304+
@mock.patch('airflow.providers.google.cloud.operators.mysql_to_gcs.MySqlHook')
305+
@mock.patch('airflow.providers.google.cloud.operators.sql_to_gcs.GCSHook')
306+
def test_schema_file_with_custom_schema(self, gcs_hook_mock_class, mysql_hook_mock_class):
307+
"""Test writing schema files with customized schema"""
308+
mysql_hook_mock = mysql_hook_mock_class.return_value
309+
mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
310+
mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
311+
312+
gcs_hook_mock = gcs_hook_mock_class.return_value
313+
314+
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
315+
if obj == SCHEMA_FILENAME:
316+
self.assertFalse(gzip)
317+
with open(tmp_filename, 'rb') as file:
318+
self.assertEqual(b''.join(CUSTOM_SCHEMA_JSON), file.read())
319+
320+
gcs_hook_mock.upload.side_effect = _assert_upload
321+
322+
op = MySQLToGCSOperator(
323+
task_id=TASK_ID,
324+
sql=SQL,
325+
bucket=BUCKET,
326+
filename=JSON_FILENAME,
327+
schema_filename=SCHEMA_FILENAME,
328+
schema=SCHEMA)
329+
op.execute(None)
330+
331+
# once for the file and once for the schema
332+
self.assertEqual(2, gcs_hook_mock.upload.call_count)
333+
296334
@mock.patch('airflow.providers.google.cloud.operators.mysql_to_gcs.MySqlHook')
297335
@mock.patch('airflow.providers.google.cloud.operators.sql_to_gcs.GCSHook')
298336
def test_query_with_error(self, mock_gcs_hook, mock_mysql_hook):

0 commit comments

Comments (0)