
import abc
import json
+import os
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence

@@ -77,6 +78,10 @@ class BaseSQLToGCSOperator(BaseOperator):
        account from the list granting this role to the originating account (templated).
    :param upload_metadata: whether to upload the row count metadata as blob metadata
    :param exclude_columns: set of columns to exclude from transmission
+    :param partition_columns: list of columns to use for file partitioning. In order to use
+        this parameter, you must sort your dataset by partition_columns. Do this by
+        passing an ORDER BY clause to the SQL query. Files are uploaded to GCS as objects
+        with a Hive-style partitioning directory structure (templated).
    """

    template_fields: Sequence[str] = (
@@ -87,6 +92,7 @@ class BaseSQLToGCSOperator(BaseOperator):
        "schema",
        "parameters",
        "impersonation_chain",
+        "partition_columns",
    )
    template_ext: Sequence[str] = (".sql",)
    template_fields_renderers = {"sql": "sql"}
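
For context, here is a hypothetical usage sketch of the new parameter. The operator subclass, connection, table, and bucket names are illustrative only; the key points are the ORDER BY clause matching partition_columns, as the docstring above requires, and the `{}` placeholder in `filename` that receives the running file number.

```python
# Illustrative sketch, not part of this diff. PostgresToGCSOperator is one concrete
# subclass of BaseSQLToGCSOperator; any subclass would accept the parameter the same way.
from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator

upload_sales = PostgresToGCSOperator(
    task_id="sales_to_gcs",  # hypothetical task id
    sql="SELECT * FROM sales ORDER BY country, year",  # sorted by the partition columns
    bucket="my-bucket",  # hypothetical bucket
    filename="sales/export_{}.csv",  # {} is filled with the file counter
    export_format="csv",
    partition_columns=["country", "year"],
    # produces objects such as sales/country=US/year=2023/export_0.csv
)
```
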
@@ -111,7 +117,8 @@ def __init__(
        delegate_to: str | None = None,
        impersonation_chain: str | Sequence[str] | None = None,
        upload_metadata: bool = False,
-        exclude_columns=None,
+        exclude_columns: set | None = None,
+        partition_columns: list | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
@@ -135,8 +142,16 @@ def __init__(
        self.impersonation_chain = impersonation_chain
        self.upload_metadata = upload_metadata
        self.exclude_columns = exclude_columns
+        self.partition_columns = partition_columns

    def execute(self, context: Context):
+        if self.partition_columns:
+            self.log.info(
+                f"Found partition columns: {','.join(self.partition_columns)}. "
+                "Assuming the SQL statement is properly sorted by these columns in "
+                "ascending or descending order."
+            )
+
        self.log.info("Executing query")
        cursor = self.query()

@@ -158,6 +173,7 @@ def execute(self, context: Context):
        total_files = 0
        self.log.info("Writing local data files")
        for file_to_upload in self._write_local_data_files(cursor):
+
            # Flush file before uploading
            file_to_upload["file_handle"].flush()

@@ -204,36 +220,56 @@ def _write_local_data_files(self, cursor):
            names in GCS, and values are file handles to local files that
            contain the data for the GCS objects.
        """
-        import os
-
        org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
        schema = [column for column in org_schema if column not in self.exclude_columns]

        col_type_dict = self._get_col_type_dict()
        file_no = 0
-
-        tmp_file_handle = NamedTemporaryFile(delete=True)
-        if self.export_format == "csv":
-            file_mime_type = "text/csv"
-        elif self.export_format == "parquet":
-            file_mime_type = "application/octet-stream"
-        else:
-            file_mime_type = "application/json"
-        file_to_upload = {
-            "file_name": self.filename.format(file_no),
-            "file_handle": tmp_file_handle,
-            "file_mime_type": file_mime_type,
-            "file_row_count": 0,
-        }
+        file_mime_type = self._get_file_mime_type()
+        file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)

        if self.export_format == "csv":
            csv_writer = self._configure_csv_file(tmp_file_handle, schema)
        if self.export_format == "parquet":
            parquet_schema = self._convert_parquet_schema(cursor)
            parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

+        prev_partition_values = None
+        curr_partition_values = None
        for row in cursor:
+            if self.partition_columns:
+                row_dict = dict(zip(schema, row))
+                curr_partition_values = tuple(
+                    [row_dict.get(partition_column, "") for partition_column in self.partition_columns]
+                )
+
+                if prev_partition_values is None:
+                    # We haven't set prev_partition_values before. Set to current
+                    prev_partition_values = curr_partition_values
+
+                elif prev_partition_values != curr_partition_values:
+                    # If the partition values differ, write the current local file out
+                    # Yield first before we write the current record
+                    file_no += 1
+
+                    if self.export_format == "parquet":
+                        parquet_writer.close()
+
+                    file_to_upload["partition_values"] = prev_partition_values
+                    yield file_to_upload
+                    file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
+                    if self.export_format == "csv":
+                        csv_writer = self._configure_csv_file(tmp_file_handle, schema)
+                    if self.export_format == "parquet":
+                        parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
+                    # Reset previous to current after writing out the file
+                    prev_partition_values = curr_partition_values
+
+            # Incrementing file_row_count after partition yield ensures all rows are written
            file_to_upload["file_row_count"] += 1
+
+            # Proceed to write the row to the local file
            if self.export_format == "csv":
                row = self.convert_types(schema, col_type_dict, row)
                if self.null_marker is not None:
@@ -268,24 +304,44 @@ def _write_local_data_files(self, cursor):

                if self.export_format == "parquet":
                    parquet_writer.close()
+
+                file_to_upload["partition_values"] = curr_partition_values
                yield file_to_upload
-                tmp_file_handle = NamedTemporaryFile(delete=True)
-                file_to_upload = {
-                    "file_name": self.filename.format(file_no),
-                    "file_handle": tmp_file_handle,
-                    "file_mime_type": file_mime_type,
-                    "file_row_count": 0,
-                }
+                file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
                if self.export_format == "csv":
                    csv_writer = self._configure_csv_file(tmp_file_handle, schema)
                if self.export_format == "parquet":
                    parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
        if self.export_format == "parquet":
            parquet_writer.close()
        # Last file may have 0 rows, don't yield if empty
        if file_to_upload["file_row_count"] > 0:
+            file_to_upload["partition_values"] = curr_partition_values
            yield file_to_upload

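A note on the boundary logic above: because rows arrive sorted, a partition ends exactly when the tuple of partition values changes. A minimal standalone sketch of the same idea (function and variable names are illustrative, not part of the operator):

```python
# Standalone sketch of the partition-boundary logic in _write_local_data_files:
# sorted rows are split into one chunk per distinct combination of partition
# values, mirroring the prev/curr tuple comparison in the diff above.
def split_by_partition(rows, schema, partition_columns):
    chunk, prev_values = [], None
    for row in rows:
        row_dict = dict(zip(schema, row))
        curr_values = tuple(row_dict.get(col, "") for col in partition_columns)
        if prev_values is not None and curr_values != prev_values:
            yield prev_values, chunk  # partition changed: flush the finished chunk
            chunk = []
        prev_values = curr_values
        chunk.append(row)
    if chunk:
        yield prev_values, chunk  # final chunk, analogous to the last yield above

# Example with two partitions. Unsorted input would emit duplicate partition
# chunks, which is why the docstring requires an ORDER BY clause.
rows = [("US", 2023, 10), ("US", 2023, 20), ("US", 2024, 5)]
for values, chunk in split_by_partition(rows, ["country", "year", "amount"], ["country", "year"]):
    print(values, len(chunk))  # ('US', 2023) 2, then ('US', 2024) 1
```
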
+    def _get_file_to_upload(self, file_mime_type, file_no):
+        """Returns a dictionary that represents the file to upload."""
+        tmp_file_handle = NamedTemporaryFile(delete=True)
+        return (
+            {
+                "file_name": self.filename.format(file_no),
+                "file_handle": tmp_file_handle,
+                "file_mime_type": file_mime_type,
+                "file_row_count": 0,
+            },
+            tmp_file_handle,
+        )
+
+    def _get_file_mime_type(self):
+        if self.export_format == "csv":
+            file_mime_type = "text/csv"
+        elif self.export_format == "parquet":
+            file_mime_type = "application/octet-stream"
+        else:
+            file_mime_type = "application/json"
+        return file_mime_type
+
    def _configure_csv_file(self, file_handle, schema):
        """Configure a csv writer with the file_handle and write schema
        as headers for the new file.
@@ -400,9 +456,19 @@ def _upload_to_gcs(self, file_to_upload):
        if is_data_file and self.upload_metadata:
            metadata = {"row_count": file_to_upload["file_row_count"]}

+        object_name = file_to_upload.get("file_name")
+        if is_data_file and self.partition_columns:
+            # Add partition column values to object_name
+            partition_values = file_to_upload.get("partition_values")
+            head_path, tail_path = os.path.split(object_name)
+            partition_subprefix = [
+                f"{col}={val}" for col, val in zip(self.partition_columns, partition_values)
+            ]
+            object_name = os.path.join(head_path, *partition_subprefix, tail_path)
+
        hook.upload(
            self.bucket,
-            file_to_upload.get("file_name"),
+            object_name,
            file_to_upload.get("file_handle").name,
            mime_type=file_to_upload.get("file_mime_type"),
            gzip=self.gzip if is_data_file else False,
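
The Hive-style object naming in _upload_to_gcs can be checked in isolation. A small sketch of the same os.path logic with illustrative values (note that os.path.join uses the platform separator, which matches the "/" delimiter of GCS object names on POSIX):

```python
import os

# Mirrors the object_name construction in _upload_to_gcs: the partition
# directories are spliced between the directory part and the file part of
# the rendered filename template. All values here are illustrative.
filename = "sales/export_0.csv"
partition_columns = ["country", "year"]
partition_values = ("US", 2023)

head_path, tail_path = os.path.split(filename)
partition_subprefix = [f"{col}={val}" for col, val in zip(partition_columns, partition_values)]
print(os.path.join(head_path, *partition_subprefix, tail_path))
# sales/country=US/year=2023/export_0.csv
```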