Skip to content

Commit e926bb9

Browse files
Add deferrable mode to DataFusionStartPipelineOperator (#28690)
1 parent 5fcdd32 commit e926bb9

File tree

16 files changed

+1313
-215
lines changed

16 files changed

+1313
-215
lines changed

airflow/providers/google/cloud/hooks/datafusion.py

Lines changed: 99 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,20 @@
2121
import os
2222
from time import monotonic, sleep
2323
from typing import Any, Dict, Sequence
24-
from urllib.parse import quote, urlencode
24+
from urllib.parse import quote, urlencode, urljoin
2525

2626
import google.auth
27+
from aiohttp import ClientSession
28+
from gcloud.aio.auth import AioSession, Token
2729
from google.api_core.retry import exponential_sleep_generator
2830
from googleapiclient.discovery import Resource, build
2931

30-
from airflow.exceptions import AirflowException
31-
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
32+
from airflow.exceptions import AirflowException, AirflowNotFoundException
33+
from airflow.providers.google.common.hooks.base_google import (
34+
PROVIDE_PROJECT_ID,
35+
GoogleBaseAsyncHook,
36+
GoogleBaseHook,
37+
)
3238

3339
Operation = Dict[str, Any]
3440

@@ -154,12 +160,14 @@ def _cdap_request(
154160

155161
@staticmethod
def _check_response_status_and_data(response, message: str) -> None:
    """Validate a CDAP REST response, raising on error statuses or an empty body.

    :param response: HTTP response object exposing ``status`` and ``data``.
    :param message: Error message used for the raised exception.
    :raises AirflowNotFoundException: when the resource was not found (HTTP 404).
    :raises AirflowException: on any other non-200 status, or when the
        response carried no data at all.
    """
    status = response.status
    if status == 404:
        raise AirflowNotFoundException(message)
    if status != 200:
        raise AirflowException(message)
    if response.data is None:
        raise AirflowException(
            "Empty response received. Please, check for possible root "
            "causes of this behavior either in DAG code or on Cloud DataFusion side"
        )
164172

165173
def get_conn(self) -> Resource:
@@ -418,7 +426,7 @@ def start_pipeline(
418426
:param pipeline_name: Your pipeline name.
419427
:param instance_url: Endpoint on which the REST APIs is accessible for the instance.
420428
:param runtime_args: Optional runtime JSON args to be passed to the pipeline
421-
:param namespace: f your pipeline belongs to a Basic edition instance, the namespace ID
429+
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
422430
is always default. If your pipeline belongs to an Enterprise edition instance, you
423431
can create a namespace.
424432
"""
@@ -469,3 +477,88 @@ def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str =
469477
self._check_response_status_and_data(
470478
response, f"Stopping a pipeline failed with code {response.status}"
471479
)
480+
481+
482+
class DataFusionAsyncHook(GoogleBaseAsyncHook):
    """Class to get asynchronous hook for Google Cloud Data Fusion."""

    sync_hook_class = DataFusionHook
    scopes = ["https://www.googleapis.com/auth/cloud-platform"]

    @staticmethod
    def _base_url(instance_url: str, namespace: str) -> str:
        """Return the base REST URL for pipeline (app) resources in *namespace*."""
        return urljoin(f"{instance_url}/", f"v3/namespaces/{quote(namespace)}/apps/")

    async def _get_link(self, url: str, session):
        """Perform an authenticated GET request against *url*.

        :param url: Full CDAP REST URL to fetch.
        :param session: An ``aiohttp`` client session to issue the request with.
        :return: The HTTP response object, or ``None`` when the request raised
            an ``AirflowException`` (e.g. the pipeline run is not visible in
            the system yet).
        """
        # BUGFIX: previously `pipeline` was only bound inside the `try`, so the
        # swallowed AirflowException led to an UnboundLocalError on return.
        pipeline = None
        async with Token(scopes=self.scopes) as token:
            session_aio = AioSession(session)
            headers = {
                "Authorization": f"Bearer {await token.get()}",
            }
            try:
                pipeline = await session_aio.get(url=url, headers=headers)
            except AirflowException:
                pass  # Because the pipeline may not be visible in system yet
        return pipeline

    async def get_pipeline(
        self,
        instance_url: str,
        namespace: str,
        pipeline_name: str,
        pipeline_id: str,
        session,
    ):
        """Fetch a single pipeline run resource; see :meth:`_get_link` for the return contract."""
        base_url_link = self._base_url(instance_url, namespace)
        url = urljoin(
            base_url_link, f"{quote(pipeline_name)}/workflows/DataPipelineWorkflow/runs/{quote(pipeline_id)}"
        )
        return await self._get_link(url=url, session=session)

    async def get_pipeline_status(
        self,
        pipeline_name: str,
        instance_url: str,
        pipeline_id: str,
        namespace: str = "default",
        success_states: list[str] | None = None,
    ) -> str:
        """
        Gets a Cloud Data Fusion pipeline status asynchronously.

        :param pipeline_name: Your pipeline name.
        :param instance_url: Endpoint on which the REST APIs is accessible for the instance.
        :param pipeline_id: Unique pipeline ID associated with specific pipeline
        :param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
            is always default. If your pipeline belongs to an Enterprise edition instance, you
            can create a namespace.
        :param success_states: If provided the operator will wait for pipeline to be in one of
            the provided states.
        :return: ``"success"``, ``"failed"``, ``"pending"``, or an error message string.
        """
        success_states = success_states or SUCCESS_STATES
        async with ClientSession() as session:
            try:
                pipeline = await self.get_pipeline(
                    instance_url=instance_url,
                    namespace=namespace,
                    pipeline_name=pipeline_name,
                    pipeline_id=pipeline_id,
                    session=session,
                )
                self.log.info("Response pipeline: %s", pipeline)
                if pipeline is None:
                    # The run is not visible in the system yet; keep polling.
                    pipeline_status = "pending"
                else:
                    pipeline = await pipeline.json(content_type=None)
                    current_pipeline_state = pipeline["status"]

                    if current_pipeline_state in success_states:
                        pipeline_status = "success"
                    elif current_pipeline_state in FAILURE_STATES:
                        pipeline_status = "failed"
                    else:
                        pipeline_status = "pending"
            except OSError:
                # Transient network failure: treat as still-running and retry later.
                pipeline_status = "pending"
            except Exception as e:
                self.log.info("Retrieving pipeline status finished with errors...")
                pipeline_status = str(e)
        return pipeline_status
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
"""This module contains Google Compute Engine links."""
19+
from __future__ import annotations
20+
21+
from typing import TYPE_CHECKING, ClassVar
22+
23+
from airflow.models import BaseOperatorLink, XCom
24+
25+
if TYPE_CHECKING:
26+
from airflow.models import BaseOperator
27+
from airflow.models.taskinstance import TaskInstanceKey
28+
from airflow.utils.context import Context
29+
30+
31+
# Console entry point for the Data Fusion service.
BASE_LINK = "https://console.cloud.google.com/data-fusion"
# Console page of a single Data Fusion instance (region/name/project placeholders).
DATAFUSION_INSTANCE_LINK = BASE_LINK + "/locations/{region}/instances/{instance_name}?project={project_id}"
# Instance-hosted UI pages; '{uri}' is the instance's own service URI, not the console.
DATAFUSION_PIPELINES_LINK = "{uri}/cdap/ns/default/pipelines"
DATAFUSION_PIPELINE_LINK = "{uri}/pipelines/ns/default/view/{pipeline_name}"
35+
36+
37+
class BaseGoogleLink(BaseOperatorLink):
    """
    Override the base logic to prevent adding 'https://console.cloud.google.com'
    in front of every link where uri is used
    """

    # Subclasses must define: a human-readable link name, the XCom key the
    # operator persists the parameters under, and the URL format template.
    name: ClassVar[str]
    key: ClassVar[str]
    format_str: ClassVar[str]

    def get_link(
        self,
        operator: BaseOperator,
        *,
        ti_key: TaskInstanceKey,
    ) -> str:
        """Render the link URL from parameters persisted in XCom.

        :param operator: The operator the link is attached to (unused here).
        :param ti_key: Key of the task instance whose XCom holds the parameters.
        :return: The formatted URL, or "" when no parameters were persisted.
        """
        conf = XCom.get_value(key=self.key, ti_key=ti_key)
        if not conf:
            return ""
        # The original `startswith("http")` branch returned the exact same
        # expression as the fall-through, so the conditional was dead code:
        # format_str is already either a full console URL or a '{uri}' template.
        return self.format_str.format(**conf)
59+
60+
61+
class DataFusionInstanceLink(BaseGoogleLink):
    """Helper class for constructing Data Fusion Instance link"""

    name = "Data Fusion Instance"
    key = "instance_conf"
    format_str = DATAFUSION_INSTANCE_LINK

    @staticmethod
    def persist(
        context: Context,
        task_instance: BaseOperator,
        location: str,
        instance_name: str,
        project_id: str,
    ):
        """Store the link parameters in XCom so ``get_link`` can render the URL later."""
        link_params = {
            "region": location,
            "instance_name": instance_name,
            "project_id": project_id,
        }
        task_instance.xcom_push(
            context=context,
            key=DataFusionInstanceLink.key,
            value=link_params,
        )
85+
86+
87+
class DataFusionPipelineLink(BaseGoogleLink):
    """Helper class for constructing Data Fusion Pipeline link"""

    name = "Data Fusion Pipeline"
    key = "pipeline_conf"
    format_str = DATAFUSION_PIPELINE_LINK

    @staticmethod
    def persist(
        context: Context,
        task_instance: BaseOperator,
        uri: str,
        pipeline_name: str,
    ):
        """Store the link parameters in XCom so ``get_link`` can render the URL later."""
        link_params = {
            "uri": uri,
            "pipeline_name": pipeline_name,
        }
        task_instance.xcom_push(
            context=context,
            key=DataFusionPipelineLink.key,
            value=link_params,
        )
109+
110+
111+
class DataFusionPipelinesLink(BaseGoogleLink):
    """Helper class for constructing list of Data Fusion Pipelines link"""

    name = "Data Fusion Pipelines List"
    key = "pipelines_conf"
    format_str = DATAFUSION_PIPELINES_LINK

    @staticmethod
    def persist(
        context: Context,
        task_instance: BaseOperator,
        uri: str,
    ):
        """Store the instance URI in XCom so ``get_link`` can render the URL later."""
        link_params = {"uri": uri}
        task_instance.xcom_push(
            context=context,
            key=DataFusionPipelinesLink.key,
            value=link_params,
        )

0 commit comments

Comments
 (0)