@@ -297,7 +297,7 @@ def __init__(
297297 * ,
298298 bucket_name : str ,
299299 objects : list [str ] | None = None ,
300- prefix : str | None = None ,
300+ prefix : str | list [ str ] | None = None ,
301301 gcp_conn_id : str = "google_cloud_default" ,
302302 impersonation_chain : str | Sequence [str ] | None = None ,
303303 ** kwargs ,
@@ -309,12 +309,14 @@ def __init__(
309309 self .impersonation_chain = impersonation_chain
310310
311311 if objects is None and prefix is None :
312- err_message = "(Task {task_id}) Either object or prefix should be set. Both are None." .format (
312+ err_message = "(Task {task_id}) Either objects or prefix should be set. Both are None." .format (
313313 ** kwargs
314314 )
315315 raise ValueError (err_message )
316+ if objects is not None and prefix is not None :
317+ err_message = "(Task {task_id}) Objects or prefix should be set. Both provided." .format (** kwargs )
318+ raise ValueError (err_message )
316319
317- self ._objects : list [str ] = []
318320 super ().__init__ (** kwargs )
319321
320322 def execute (self , context : Context ) -> None :
@@ -324,15 +326,14 @@ def execute(self, context: Context) -> None:
324326 )
325327
326328 if self .objects is not None :
327- self . _objects = self .objects
329+ objects = self .objects
328330 else :
329- self . _objects = hook .list (bucket_name = self .bucket_name , prefix = self .prefix )
330- self .log .info ("Deleting %s objects from %s" , len (self . _objects ), self .bucket_name )
331- for object_name in self . _objects :
331+ objects = hook .list (bucket_name = self .bucket_name , prefix = self .prefix )
332+ self .log .info ("Deleting %s objects from %s" , len (objects ), self .bucket_name )
333+ for object_name in objects :
332334 hook .delete (bucket_name = self .bucket_name , object_name = object_name )
333335
334- def get_openlineage_facets_on_complete (self , task_instance ):
335- """Implement on_complete as execute() resolves object names."""
336+ def get_openlineage_facets_on_start (self ):
336337 from openlineage .client .facet import (
337338 LifecycleStateChange ,
338339 LifecycleStateChangeDatasetFacet ,
@@ -342,8 +343,17 @@ def get_openlineage_facets_on_complete(self, task_instance):
342343
343344 from airflow .providers .openlineage .extractors import OperatorLineage
344345
345- if not self ._objects :
346- return OperatorLineage ()
346+ objects = []
347+ if self .objects is not None :
348+ objects = self .objects
349+ elif self .prefix is not None :
350+ prefixes = [self .prefix ] if isinstance (self .prefix , str ) else self .prefix
351+ for pref in prefixes :
352+ # Use parent if not a file (dot not in name) and not a dir (ends with slash)
353+ if "." not in pref .split ("/" )[- 1 ] and not pref .endswith ("/" ):
354+ pref = Path (pref ).parent .as_posix ()
355+ pref = "/" if pref in ("." , "" , "/" ) else pref .rstrip ("/" )
356+ objects .append (pref )
347357
348358 bucket_url = f"gs://{ self .bucket_name } "
349359 input_datasets = [
@@ -360,7 +370,7 @@ def get_openlineage_facets_on_complete(self, task_instance):
360370 )
361371 },
362372 )
363- for object_name in self . _objects
373+ for object_name in objects
364374 ]
365375
366376 return OperatorLineage (inputs = input_datasets )
0 commit comments