
Commit 35a8ffc

vchiapaikeo and eladkal authored

Support partition_columns in BaseSQLToGCSOperator (#28677)

* Support partition_columns in BaseSQLToGCSOperator

Co-authored-by: eladkal <45845474+eladkal@users.noreply.github.com>
1 parent 07a17ba commit 35a8ffc
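For context, a minimal usage sketch of the new parameter. The operator, connection id, table, bucket, and filename below are illustrative placeholders (PostgresToGCSOperator is just one concrete subclass of BaseSQLToGCSOperator, used here as an assumption); the one real requirement, per the docstring added in this commit, is that the query is sorted by the partition columns:

# Hypothetical DAG task: all identifiers below are placeholders, not part of this commit.
from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator

export_orders = PostgresToGCSOperator(
    task_id="export_orders",
    postgres_conn_id="postgres_default",
    # Sort by the partition columns so consecutive rows share partition values.
    sql="SELECT * FROM orders ORDER BY country, order_date",
    bucket="my-export-bucket",
    filename="orders/export_{}.parquet",
    export_format="parquet",
    # New in this commit: objects land under country=.../order_date=.../ prefixes in GCS.
    partition_columns=["country", "order_date"],
)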

File tree: 2 files changed (+188, -25 lines)


airflow/providers/google/cloud/transfers/sql_to_gcs.py

Lines changed: 91 additions & 25 deletions
@@ -20,6 +20,7 @@
 
 import abc
 import json
+import os
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence
 
@@ -77,6 +78,10 @@ class BaseSQLToGCSOperator(BaseOperator):
         account from the list granting this role to the originating account (templated).
     :param upload_metadata: whether to upload the row count metadata as blob metadata
     :param exclude_columns: set of columns to exclude from transmission
+    :param partition_columns: list of columns to use for file partitioning. In order to use
+        this parameter, you must sort your dataset by partition_columns. Do this by
+        passing an ORDER BY clause to the sql query. Files are uploaded to GCS as objects
+        with a hive style partitioning directory structure (templated).
     """
 
     template_fields: Sequence[str] = (
@@ -87,6 +92,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         "schema",
         "parameters",
         "impersonation_chain",
+        "partition_columns",
     )
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"sql": "sql"}
@@ -111,7 +117,8 @@ def __init__(
         delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         upload_metadata: bool = False,
-        exclude_columns=None,
+        exclude_columns: set | None = None,
+        partition_columns: list | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -135,8 +142,16 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.upload_metadata = upload_metadata
         self.exclude_columns = exclude_columns
+        self.partition_columns = partition_columns
 
     def execute(self, context: Context):
+        if self.partition_columns:
+            self.log.info(
+                f"Found partition columns: {','.join(self.partition_columns)}. "
+                "Assuming the SQL statement is properly sorted by these columns in "
+                "ascending or descending order."
+            )
+
         self.log.info("Executing query")
         cursor = self.query()
 
@@ -158,6 +173,7 @@ def execute(self, context: Context):
         total_files = 0
         self.log.info("Writing local data files")
         for file_to_upload in self._write_local_data_files(cursor):
+
             # Flush file before uploading
             file_to_upload["file_handle"].flush()
 
@@ -204,36 +220,56 @@ def _write_local_data_files(self, cursor):
             names in GCS, and values are file handles to local files that
             contain the data for the GCS objects.
         """
-        import os
-
         org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
         schema = [column for column in org_schema if column not in self.exclude_columns]
 
         col_type_dict = self._get_col_type_dict()
         file_no = 0
-
-        tmp_file_handle = NamedTemporaryFile(delete=True)
-        if self.export_format == "csv":
-            file_mime_type = "text/csv"
-        elif self.export_format == "parquet":
-            file_mime_type = "application/octet-stream"
-        else:
-            file_mime_type = "application/json"
-        file_to_upload = {
-            "file_name": self.filename.format(file_no),
-            "file_handle": tmp_file_handle,
-            "file_mime_type": file_mime_type,
-            "file_row_count": 0,
-        }
+        file_mime_type = self._get_file_mime_type()
+        file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
 
         if self.export_format == "csv":
             csv_writer = self._configure_csv_file(tmp_file_handle, schema)
         if self.export_format == "parquet":
             parquet_schema = self._convert_parquet_schema(cursor)
             parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
 
+        prev_partition_values = None
+        curr_partition_values = None
         for row in cursor:
+            if self.partition_columns:
+                row_dict = dict(zip(schema, row))
+                curr_partition_values = tuple(
+                    [row_dict.get(partition_column, "") for partition_column in self.partition_columns]
+                )
+
+                if prev_partition_values is None:
+                    # We haven't set prev_partition_values before. Set to current
+                    prev_partition_values = curr_partition_values
+
+                elif prev_partition_values != curr_partition_values:
+                    # If the partition values differ, write the current local file out
+                    # Yield first before we write the current record
+                    file_no += 1
+
+                    if self.export_format == "parquet":
+                        parquet_writer.close()
+
+                    file_to_upload["partition_values"] = prev_partition_values
+                    yield file_to_upload
+                    file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
+                    if self.export_format == "csv":
+                        csv_writer = self._configure_csv_file(tmp_file_handle, schema)
+                    if self.export_format == "parquet":
+                        parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
+                    # Reset previous to current after writing out the file
+                    prev_partition_values = curr_partition_values
+
+            # Incrementing file_row_count after partition yield ensures all rows are written
             file_to_upload["file_row_count"] += 1
+
+            # Proceed to write the row to the localfile
             if self.export_format == "csv":
                 row = self.convert_types(schema, col_type_dict, row)
                 if self.null_marker is not None:
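The hunk above rolls over to a new local file whenever the current row's partition values differ from the previous row's, which is why the SQL must be sorted by the partition columns. A standalone sketch of the same grouping idea, with hypothetical names that are not part of the operator:

# Sketch: split a sorted row stream into chunks of identical partition values.
def split_by_partition(rows, schema, partition_columns):
    prev_values = None
    chunk = []
    for row in rows:
        row_dict = dict(zip(schema, row))
        values = tuple(row_dict.get(col, "") for col in partition_columns)
        if prev_values is not None and values != prev_values:
            yield prev_values, chunk  # partition boundary reached: emit the finished chunk
            chunk = []
        prev_values = values
        chunk.append(row)
    if chunk:
        yield prev_values, chunk  # emit the final chunk

# Rows sorted by "country" produce three chunks (and hence three files):
rows = [(1, "US"), (2, "US"), (3, "DE"), (4, "FR")]
for values, chunk in split_by_partition(rows, ["id", "country"], ["country"]):
    print(values, len(chunk))  # ('US',) 2   ('DE',) 1   ('FR',) 1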
@@ -268,24 +304,44 @@ def _write_local_data_files(self, cursor):
 
                 if self.export_format == "parquet":
                     parquet_writer.close()
+
+                file_to_upload["partition_values"] = curr_partition_values
                 yield file_to_upload
-                tmp_file_handle = NamedTemporaryFile(delete=True)
-                file_to_upload = {
-                    "file_name": self.filename.format(file_no),
-                    "file_handle": tmp_file_handle,
-                    "file_mime_type": file_mime_type,
-                    "file_row_count": 0,
-                }
+                file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
                 if self.export_format == "csv":
                     csv_writer = self._configure_csv_file(tmp_file_handle, schema)
                 if self.export_format == "parquet":
                     parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
         if self.export_format == "parquet":
             parquet_writer.close()
         # Last file may have 0 rows, don't yield if empty
         if file_to_upload["file_row_count"] > 0:
+            file_to_upload["partition_values"] = curr_partition_values
             yield file_to_upload
 
+    def _get_file_to_upload(self, file_mime_type, file_no):
+        """Returns a dictionary that represents the file to upload"""
+        tmp_file_handle = NamedTemporaryFile(delete=True)
+        return (
+            {
+                "file_name": self.filename.format(file_no),
+                "file_handle": tmp_file_handle,
+                "file_mime_type": file_mime_type,
+                "file_row_count": 0,
+            },
+            tmp_file_handle,
+        )
+
+    def _get_file_mime_type(self):
+        if self.export_format == "csv":
+            file_mime_type = "text/csv"
+        elif self.export_format == "parquet":
+            file_mime_type = "application/octet-stream"
+        else:
+            file_mime_type = "application/json"
+        return file_mime_type
+
     def _configure_csv_file(self, file_handle, schema):
         """Configure a csv writer with the file_handle and write schema
         as headers for the new file.
@@ -400,9 +456,19 @@ def _upload_to_gcs(self, file_to_upload):
         if is_data_file and self.upload_metadata:
             metadata = {"row_count": file_to_upload["file_row_count"]}
 
+        object_name = file_to_upload.get("file_name")
+        if is_data_file and self.partition_columns:
+            # Add partition column values to object_name
+            partition_values = file_to_upload.get("partition_values")
+            head_path, tail_path = os.path.split(object_name)
+            partition_subprefix = [
+                f"{col}={val}" for col, val in zip(self.partition_columns, partition_values)
+            ]
+            object_name = os.path.join(head_path, *partition_subprefix, tail_path)
+
         hook.upload(
             self.bucket,
-            file_to_upload.get("file_name"),
+            object_name,
             file_to_upload.get("file_handle").name,
             mime_type=file_to_upload.get("file_mime_type"),
             gzip=self.gzip if is_data_file else False,
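The _upload_to_gcs change above splices the partition directories between the configured path prefix and the file name. A small sketch of that naming logic with placeholder values:

import os

# Placeholder inputs standing in for file_to_upload["file_name"] and ["partition_values"].
filename = "orders/export_3.parquet"
partition_columns = ["country", "order_date"]
partition_values = ("US", "2023-01-01")

head_path, tail_path = os.path.split(filename)
partition_subprefix = [f"{col}={val}" for col, val in zip(partition_columns, partition_values)]
object_name = os.path.join(head_path, *partition_subprefix, tail_path)

print(object_name)  # orders/country=US/order_date=2023-01-01/export_3.parquet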

tests/providers/google/cloud/transfers/test_sql_to_gcs.py

Lines changed: 97 additions & 0 deletions
@@ -62,6 +62,7 @@
 OUTPUT_DF = pd.DataFrame([["convert_type_return_value"] * 3] * 3, columns=COLUMNS)
 
 EXCLUDE_COLUMNS = set("column_c")
+PARTITION_COLUMNS = ["column_b", "column_c"]
 NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
 OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame(
     [["convert_type_return_value"] * len(NEW_COLUMNS)] * 3, columns=NEW_COLUMNS
@@ -305,6 +306,74 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
         )
         mock_close.assert_called_once()
 
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
+        # Test partition columns
+        operator = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            export_format="parquet",
+            schema=SCHEMA,
+            partition_columns=PARTITION_COLUMNS,
+        )
+        result = operator.execute(context=dict())
+
+        assert result == {
+            "bucket": "TEST-BUCKET-1",
+            "total_row_count": 3,
+            "total_files": 3,
+            "files": [
+                {
+                    "file_name": "test_results_0.csv",
+                    "file_mime_type": "application/octet-stream",
+                    "file_row_count": 1,
+                },
+                {
+                    "file_name": "test_results_1.csv",
+                    "file_mime_type": "application/octet-stream",
+                    "file_row_count": 1,
+                },
+                {
+                    "file_name": "test_results_2.csv",
+                    "file_mime_type": "application/octet-stream",
+                    "file_row_count": 1,
+                },
+            ],
+        }
+
+        mock_query.assert_called_once()
+        assert mock_flush.call_count == 3
+        assert mock_close.call_count == 3
+        mock_upload.assert_has_calls(
+            [
+                mock.call(
+                    BUCKET,
+                    f"column_b={row[1]}/column_c={row[2]}/test_results_{i}.csv",
+                    TMP_FILE_NAME,
+                    mime_type="application/octet-stream",
+                    gzip=False,
+                    metadata=None,
+                )
+                for i, row in enumerate(INPUT_DATA)
+            ]
+        )
+
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
         # Test null marker
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
         mock_convert_type.return_value = None
@@ -423,3 +492,31 @@ def test__write_local_data_files_json_with_exclude_columns(self):
             file.flush()
             df = pd.read_json(file.name, orient="records", lines=True)
             assert df.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)
+
+    def test__write_local_data_files_parquet_with_partition_columns(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="parquet",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id="google_cloud_default",
+            partition_columns=PARTITION_COLUMNS,
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        local_data_files = op._write_local_data_files(cursor)
+        concat_dfs = []
+        for local_data_file in local_data_files:
+            file = local_data_file["file_handle"]
+            file.flush()
+            df = pd.read_parquet(file.name)
+            concat_dfs.append(df)
+
+        concat_df = pd.concat(concat_dfs, ignore_index=True)
+        assert concat_df.equals(OUTPUT_DF)
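As a follow-up to the partitioned-parquet test above, one reason for the hive-style layout: readers that understand hive partitioning can recover the partition columns from the directory names when loading the exported objects. A sketch under the assumption that pyarrow (with GCS filesystem support) is available; the bucket path is a placeholder and this snippet is not part of the commit:

# Sketch: read the hive-partitioned export back, recovering the partition columns.
import pyarrow.dataset as ds

dataset = ds.dataset(
    "gs://my-export-bucket/orders/",
    format="parquet",
    partitioning="hive",  # parse country=.../order_date=.../ prefixes into columns
)
df = dataset.to_table().to_pandas()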
