Skip to content

Commit cc76229

Browse files
authored
feat: add Hook Level Lineage support for GCSHook (#42507)
Signed-off-by: Kacper Muda <mudakacper@gmail.com>
1 parent 69af185 commit cc76229

File tree

9 files changed

+372
-7
lines changed

9 files changed

+372
-7
lines changed

generated/provider_dependencies.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@
625625
"google": {
626626
"deps": [
627627
"PyOpenSSL>=23.0.0",
628-
"apache-airflow-providers-common-compat>=1.1.0",
628+
"apache-airflow-providers-common-compat>=1.2.1",
629629
"apache-airflow-providers-common-sql>=1.7.2",
630630
"apache-airflow>=2.8.0",
631631
"asgiref>=3.5.2",

providers/src/airflow/providers/google/datasets/__init__.py renamed to providers/src/airflow/providers/google/assets/__init__.py

File renamed without changes.

providers/src/airflow/providers/google/datasets/bigquery.py renamed to providers/src/airflow/providers/google/assets/bigquery.py

File renamed without changes.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from typing import TYPE_CHECKING
20+
21+
from airflow.providers.common.compat.assets import Asset
22+
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url
23+
24+
if TYPE_CHECKING:
25+
from urllib.parse import SplitResult
26+
27+
from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
28+
29+
30+
def create_asset(*, bucket: str, key: str, extra: dict | None = None) -> Asset:
31+
return Asset(uri=f"gs://{bucket}/{key}", extra=extra)
32+
33+
34+
def sanitize_uri(uri: SplitResult) -> SplitResult:
35+
if not uri.netloc:
36+
raise ValueError("URI format gs:// must contain a bucket name")
37+
return uri
38+
39+
40+
def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset:
41+
"""Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
42+
from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
43+
44+
bucket, key = _parse_gcs_url(asset.uri)
45+
return OpenLineageDataset(namespace=f"gs://{bucket}", name=key if key else "/")

providers/src/airflow/providers/google/cloud/hooks/gcs.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from requests import Session
4444

4545
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
46+
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
4647
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
4748
from airflow.providers.google.common.consts import CLIENT_INFO
4849
from airflow.providers.google.common.hooks.base_google import (
@@ -214,6 +215,16 @@ def copy(
214215
destination_object = source_bucket.copy_blob( # type: ignore[attr-defined]
215216
blob=source_object, destination_bucket=destination_bucket, new_name=destination_object
216217
)
218+
get_hook_lineage_collector().add_input_asset(
219+
context=self,
220+
scheme="gs",
221+
asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined]
222+
)
223+
get_hook_lineage_collector().add_output_asset(
224+
context=self,
225+
scheme="gs",
226+
asset_kwargs={"bucket": destination_bucket.name, "key": destination_object.name}, # type: ignore[union-attr]
227+
)
217228

218229
self.log.info(
219230
"Object %s in bucket %s copied to object %s in bucket %s",
@@ -267,6 +278,16 @@ def rewrite(
267278
).rewrite(source=source_object, token=token)
268279

269280
self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten)
281+
get_hook_lineage_collector().add_input_asset(
282+
context=self,
283+
scheme="gs",
284+
asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined]
285+
)
286+
get_hook_lineage_collector().add_output_asset(
287+
context=self,
288+
scheme="gs",
289+
asset_kwargs={"bucket": destination_bucket.name, "key": destination_object}, # type: ignore[attr-defined]
290+
)
270291
self.log.info(
271292
"Object %s in bucket %s rewritten to object %s in bucket %s",
272293
source_object.name, # type: ignore[attr-defined]
@@ -345,9 +366,18 @@ def download(
345366

346367
if filename:
347368
blob.download_to_filename(filename, timeout=timeout)
369+
get_hook_lineage_collector().add_input_asset(
370+
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
371+
)
372+
get_hook_lineage_collector().add_output_asset(
373+
context=self, scheme="file", asset_kwargs={"path": filename}
374+
)
348375
self.log.info("File downloaded to %s", filename)
349376
return filename
350377
else:
378+
get_hook_lineage_collector().add_input_asset(
379+
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
380+
)
351381
return blob.download_as_bytes()
352382

353383
except GoogleCloudError:
@@ -555,6 +585,9 @@ def _call_with_retry(f: Callable[[], None]) -> None:
555585
_call_with_retry(
556586
partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
557587
)
588+
get_hook_lineage_collector().add_input_asset(
589+
context=self, scheme="file", asset_kwargs={"path": filename}
590+
)
558591

559592
if gzip:
560593
os.remove(filename)
@@ -576,6 +609,10 @@ def _call_with_retry(f: Callable[[], None]) -> None:
576609
else:
577610
raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")
578611

612+
get_hook_lineage_collector().add_output_asset(
613+
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
614+
)
615+
579616
def exists(self, bucket_name: str, object_name: str, retry: Retry = DEFAULT_RETRY) -> bool:
580617
"""
581618
Check for the existence of a file in Google Cloud Storage.
@@ -691,6 +728,9 @@ def delete(self, bucket_name: str, object_name: str) -> None:
691728
bucket = client.bucket(bucket_name)
692729
blob = bucket.blob(blob_name=object_name)
693730
blob.delete()
731+
get_hook_lineage_collector().add_input_asset(
732+
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
733+
)
694734

695735
self.log.info("Blob %s deleted.", object_name)
696736

@@ -1198,9 +1238,17 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec
11981238
client = self.get_conn()
11991239
bucket = client.bucket(bucket_name)
12001240
destination_blob = bucket.blob(destination_object)
1201-
destination_blob.compose(
1202-
sources=[bucket.blob(blob_name=source_object) for source_object in source_objects]
1241+
source_blobs = [bucket.blob(blob_name=source_object) for source_object in source_objects]
1242+
destination_blob.compose(sources=source_blobs)
1243+
get_hook_lineage_collector().add_output_asset(
1244+
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": destination_blob.name}
12031245
)
1246+
for single_source_blob in source_blobs:
1247+
get_hook_lineage_collector().add_input_asset(
1248+
context=self,
1249+
scheme="gs",
1250+
asset_kwargs={"bucket": bucket.name, "key": single_source_blob.name},
1251+
)
12041252

12051253
self.log.info("Completed successfully.")
12061254

providers/src/airflow/providers/google/provider.yaml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ versions:
9797

9898
dependencies:
9999
- apache-airflow>=2.8.0
100-
- apache-airflow-providers-common-compat>=1.1.0
100+
- apache-airflow-providers-common-compat>=1.2.1
101101
- apache-airflow-providers-common-sql>=1.7.2
102102
- asgiref>=3.5.2
103103
- dill>=0.2.3
@@ -777,15 +777,23 @@ asset-uris:
777777
- schemes: [gcp]
778778
handler: null
779779
- schemes: [bigquery]
780-
handler: airflow.providers.google.datasets.bigquery.sanitize_uri
780+
handler: airflow.providers.google.assets.bigquery.sanitize_uri
781+
- schemes: [gs]
782+
handler: airflow.providers.google.assets.gcs.sanitize_uri
783+
factory: airflow.providers.google.assets.gcs.create_asset
784+
to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage
781785

782786
# dataset has been renamed to asset in Airflow 3.0
783787
# This is kept for backward compatibility.
784788
dataset-uris:
785789
- schemes: [gcp]
786790
handler: null
787791
- schemes: [bigquery]
788-
handler: airflow.providers.google.datasets.bigquery.sanitize_uri
792+
handler: airflow.providers.google.assets.bigquery.sanitize_uri
793+
- schemes: [gs]
794+
handler: airflow.providers.google.assets.gcs.sanitize_uri
795+
factory: airflow.providers.google.assets.gcs.create_asset
796+
to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage
789797

790798
hooks:
791799
- integration-name: Google Ads

providers/tests/google/assets/test_bigquery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pytest
2323

24-
from airflow.providers.google.datasets.bigquery import sanitize_uri
24+
from airflow.providers.google.assets.bigquery import sanitize_uri
2525

2626

2727
def test_sanitize_uri_pass() -> None:
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
import urllib.parse
20+
21+
import pytest
22+
23+
from airflow.providers.common.compat.assets import Asset
24+
from airflow.providers.google.assets.gcs import convert_asset_to_openlineage, create_asset, sanitize_uri
25+
26+
27+
def test_sanitize_uri():
28+
uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/dir/file.txt"))
29+
result = sanitize_uri(uri)
30+
assert result.scheme == "gs"
31+
assert result.netloc == "bucket"
32+
assert result.path == "/dir/file.txt"
33+
34+
35+
def test_sanitize_uri_no_netloc():
36+
with pytest.raises(ValueError):
37+
sanitize_uri(urllib.parse.urlsplit("gs://"))
38+
39+
40+
def test_sanitize_uri_no_path():
41+
uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket"))
42+
result = sanitize_uri(uri)
43+
assert result.scheme == "gs"
44+
assert result.netloc == "bucket"
45+
assert result.path == ""
46+
47+
48+
def test_create_asset():
49+
assert create_asset(bucket="test-bucket", key="test-path") == Asset(uri="gs://test-bucket/test-path")
50+
assert create_asset(bucket="test-bucket", key="test-dir/test-path") == Asset(
51+
uri="gs://test-bucket/test-dir/test-path"
52+
)
53+
54+
55+
def test_sanitize_uri_trailing_slash():
56+
uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/"))
57+
result = sanitize_uri(uri)
58+
assert result.scheme == "gs"
59+
assert result.netloc == "bucket"
60+
assert result.path == "/"
61+
62+
63+
def test_convert_asset_to_openlineage_valid():
64+
uri = "gs://bucket/dir/file.txt"
65+
ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=None)
66+
assert ol_dataset.namespace == "gs://bucket"
67+
assert ol_dataset.name == "dir/file.txt"
68+
69+
70+
@pytest.mark.parametrize("uri", ("gs://bucket", "gs://bucket/"))
71+
def test_convert_asset_to_openlineage_no_path(uri):
72+
ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=None)
73+
assert ol_dataset.namespace == "gs://bucket"
74+
assert ol_dataset.name == "/"

0 commit comments

Comments
 (0)