Skip to content

Commit 432697d

Browse files
authored
allow multiple prefixes in gcs delete/list hooks and operators (#30815)
1 parent 2d40f41 commit 432697d

File tree

4 files changed

+88
-14
lines changed

4 files changed

+88
-14
lines changed

airflow/providers/google/cloud/hooks/gcs.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -696,15 +696,63 @@ def delete_bucket(self, bucket_name: str, force: bool = False) -> None:
696696
except NotFound:
697697
self.log.info("Bucket %s not exists", bucket_name)
698698

699-
def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimiter=None) -> List:
699+
def list(
700+
self,
701+
bucket_name: str,
702+
versions: bool | None = None,
703+
max_results: int | None = None,
704+
prefix: str | List[str] | None = None,
705+
delimiter: str | None = None,
706+
):
707+
"""
708+
List all objects from the bucket with the given a single prefix or multiple prefixes
709+
710+
:param bucket_name: bucket name
711+
:param versions: if true, list all versions of the objects
712+
:param max_results: max count of items to return in a single page of responses
713+
:param prefix: string or list of strings which filter objects whose name begin with it/them
714+
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
715+
:return: a stream of object names matching the filtering criteria
716+
"""
717+
objects = []
718+
if isinstance(prefix, list):
719+
for prefix_item in prefix:
720+
objects.extend(
721+
self._list(
722+
bucket_name=bucket_name,
723+
versions=versions,
724+
max_results=max_results,
725+
prefix=prefix_item,
726+
delimiter=delimiter,
727+
)
728+
)
729+
else:
730+
objects.extend(
731+
self._list(
732+
bucket_name=bucket_name,
733+
versions=versions,
734+
max_results=max_results,
735+
prefix=prefix,
736+
delimiter=delimiter,
737+
)
738+
)
739+
return objects
740+
741+
def _list(
742+
self,
743+
bucket_name: str,
744+
versions: bool | None = None,
745+
max_results: int | None = None,
746+
prefix: str | None = None,
747+
delimiter: str | None = None,
748+
) -> List:
700749
"""
701750
List all objects from the bucket with the give string prefix in name
702751
703752
:param bucket_name: bucket name
704753
:param versions: if true, list all versions of the objects
705754
:param max_results: max count of items to return in a single page of responses
706-
:param prefix: prefix string which filters objects whose name begin with
707-
this prefix
755+
:param prefix: string which filters objects whose name begin with it
708756
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
709757
:return: a stream of object names matching the filtering criteria
710758
"""

airflow/providers/google/cloud/operators/gcs.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
163163
XCom in the downstream task.
164164
165165
:param bucket: The Google Cloud Storage bucket to find the objects. (templated)
166-
:param prefix: Prefix string which filters objects whose name begin with
167-
this prefix. (templated)
166+
:param prefix: String or list of strings, which filter objects whose name begin with
167+
it/them. (templated)
168168
:param delimiter: The delimiter by which you want to filter the objects. (templated)
169169
For example, to lists the CSV files from in a directory in GCS you would use
170170
delimiter='.csv'.
@@ -206,7 +206,7 @@ def __init__(
206206
self,
207207
*,
208208
bucket: str,
209-
prefix: str | None = None,
209+
prefix: str | list[str] | None = None,
210210
delimiter: str | None = None,
211211
gcp_conn_id: str = "google_cloud_default",
212212
impersonation_chain: str | Sequence[str] | None = None,
@@ -220,14 +220,13 @@ def __init__(
220220
self.impersonation_chain = impersonation_chain
221221

222222
def execute(self, context: Context) -> list:
223-
224223
hook = GCSHook(
225224
gcp_conn_id=self.gcp_conn_id,
226225
impersonation_chain=self.impersonation_chain,
227226
)
228227

229228
self.log.info(
230-
"Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s",
229+
"Getting list of the files. Bucket: %s; Delimiter: %s; Prefix(es): %s",
231230
self.bucket,
232231
self.delimiter,
233232
self.prefix,
@@ -239,7 +238,6 @@ def execute(self, context: Context) -> list:
239238
uri=self.bucket,
240239
project_id=hook.project_id,
241240
)
242-
243241
return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
244242

245243

@@ -252,8 +250,8 @@ class GCSDeleteObjectsOperator(GoogleCloudBaseOperator):
252250
:param bucket_name: The GCS bucket to delete from
253251
:param objects: List of objects to delete. These should be the names
254252
of objects in the bucket, not including gs://bucket/
255-
:param prefix: Prefix of objects to delete. All objects matching this
256-
prefix in the bucket will be deleted.
253+
:param prefix: String or list of strings, which filter objects whose name begin with
254+
it/them. (templated)
257255
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
258256
:param impersonation_chain: Optional service account to impersonate using short-term
259257
credentials, or chained list of accounts required to get the access_token
@@ -307,7 +305,6 @@ def execute(self, context: Context) -> None:
307305
objects = self.objects
308306
else:
309307
objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix)
310-
311308
self.log.info("Deleting %s objects from %s", len(objects), self.bucket_name)
312309
for object_name in objects:
313310
hook.delete(bucket_name=self.bucket_name, object_name=object_name)

tests/providers/google/cloud/hooks/test_gcs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,36 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file):
758758
]
759759
)
760760

761+
@pytest.mark.parametrize(
762+
"prefix, result",
763+
(
764+
(
765+
"prefix",
766+
[mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)],
767+
),
768+
(
769+
["prefix", "prefix_2"],
770+
[
771+
mock.call(
772+
delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None
773+
),
774+
mock.call(
775+
delimiter=",", prefix="prefix_2", versions=None, max_results=None, page_token=None
776+
),
777+
],
778+
),
779+
),
780+
)
781+
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
782+
def test_list(self, mock_service, prefix, result):
783+
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None
784+
self.gcs_hook.list(
785+
bucket_name="test_bucket",
786+
prefix=prefix,
787+
delimiter=",",
788+
)
789+
assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result
790+
761791
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
762792
def test_list_by_timespans(self, mock_service):
763793
test_bucket = "test_bucket"

tests/providers/google/cloud/operators/test_gcs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
TEST_PROJECT = "test-project"
3939
DELIMITER = ".csv"
4040
PREFIX = "TEST"
41+
PREFIX_2 = "TEST2"
4142
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv", "OTHERTEST1.csv"]
4243
TEST_OBJECT = "dir1/test-object"
4344
LOCAL_FILE_PATH = "/home/airflow/gcp/test-object"
@@ -160,11 +161,9 @@ class TestGoogleCloudStorageListOperator:
160161
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
161162
def test_execute(self, mock_hook):
162163
mock_hook.return_value.list.return_value = MOCK_FILES
163-
164164
operator = GCSListObjectsOperator(
165165
task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
166166
)
167-
168167
files = operator.execute(context=mock.MagicMock())
169168
mock_hook.return_value.list.assert_called_once_with(
170169
bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER

0 commit comments

Comments
 (0)