|
# Shared fixtures for the export tests below: target GCS bucket and the
# object-name templates for the two export formats. The '{}' placeholder is
# presumably filled with a file-split counter by the operator — confirm
# against MySQLToGCSOperator's filename handling.
BUCKET = 'gs://test'
JSON_FILENAME = 'test_{}.ndjson'  # newline-delimited JSON export
CSV_FILENAME = 'test_{}.csv'
# BigQuery-style schema passed explicitly via the operator's ``schema``
# argument. NOTE(review): the types deliberately do not match the column
# names (a "str" column typed FLOAT) — this looks intended to prove the
# operator uploads the supplied schema verbatim rather than deriving one
# from the cursor description; see test_schema_file_with_custom_schema.
SCHEMA = [
    {'mode': 'REQUIRED', 'name': 'some_str', 'type': 'FLOAT'},
    {'mode': 'REQUIRED', 'name': 'some_num', 'type': 'TIMESTAMP'}
]
36 | 40 |
|
37 | 41 | ROWS = [ |
38 | 42 | ('mock_row_content_1', 42), |
|
65 | 69 | b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ', |
66 | 70 | b'{"mode": "REQUIRED", "name": "some_num", "type": "STRING"}]' |
67 | 71 | ] |
# Expected byte content of the schema file when SCHEMA is supplied
# explicitly, split into two chunks; the test joins them with b''.join()
# before comparing against the uploaded file's contents.
CUSTOM_SCHEMA_JSON = [
    b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ',
    b'{"mode": "REQUIRED", "name": "some_num", "type": "TIMESTAMP"}]'
]
68 | 76 |
|
69 | 77 |
|
70 | 78 | class TestMySqlToGoogleCloudStorageOperator(unittest.TestCase): |
@@ -293,6 +301,36 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab |
293 | 301 | # once for the file and once for the schema |
294 | 302 | self.assertEqual(2, gcs_hook_mock.upload.call_count) |
295 | 303 |
|
| 304 | + @mock.patch('airflow.providers.google.cloud.operators.mysql_to_gcs.MySqlHook') |
| 305 | + @mock.patch('airflow.providers.google.cloud.operators.sql_to_gcs.GCSHook') |
| 306 | + def test_schema_file_with_custom_schema(self, gcs_hook_mock_class, mysql_hook_mock_class): |
| 307 | + """Test writing schema files with customized schema""" |
| 308 | + mysql_hook_mock = mysql_hook_mock_class.return_value |
| 309 | + mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS) |
| 310 | + mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION |
| 311 | + |
| 312 | + gcs_hook_mock = gcs_hook_mock_class.return_value |
| 313 | + |
| 314 | + def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument |
| 315 | + if obj == SCHEMA_FILENAME: |
| 316 | + self.assertFalse(gzip) |
| 317 | + with open(tmp_filename, 'rb') as file: |
| 318 | + self.assertEqual(b''.join(CUSTOM_SCHEMA_JSON), file.read()) |
| 319 | + |
| 320 | + gcs_hook_mock.upload.side_effect = _assert_upload |
| 321 | + |
| 322 | + op = MySQLToGCSOperator( |
| 323 | + task_id=TASK_ID, |
| 324 | + sql=SQL, |
| 325 | + bucket=BUCKET, |
| 326 | + filename=JSON_FILENAME, |
| 327 | + schema_filename=SCHEMA_FILENAME, |
| 328 | + schema=SCHEMA) |
| 329 | + op.execute(None) |
| 330 | + |
| 331 | + # once for the file and once for the schema |
| 332 | + self.assertEqual(2, gcs_hook_mock.upload.call_count) |
| 333 | + |
296 | 334 | @mock.patch('airflow.providers.google.cloud.operators.mysql_to_gcs.MySqlHook') |
297 | 335 | @mock.patch('airflow.providers.google.cloud.operators.sql_to_gcs.GCSHook') |
298 | 336 | def test_query_with_error(self, mock_gcs_hook, mock_mysql_hook): |
|
0 commit comments