|
21 | 21 | import os |
22 | 22 | from time import monotonic, sleep |
23 | 23 | from typing import Any, Dict, Sequence |
24 | | -from urllib.parse import quote, urlencode |
| 24 | +from urllib.parse import quote, urlencode, urljoin |
25 | 25 |
|
26 | 26 | import google.auth |
| 27 | +from aiohttp import ClientSession |
| 28 | +from gcloud.aio.auth import AioSession, Token |
27 | 29 | from google.api_core.retry import exponential_sleep_generator |
28 | 30 | from googleapiclient.discovery import Resource, build |
29 | 31 |
|
30 | | -from airflow.exceptions import AirflowException |
31 | | -from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook |
| 32 | +from airflow.exceptions import AirflowException, AirflowNotFoundException |
| 33 | +from airflow.providers.google.common.hooks.base_google import ( |
| 34 | + PROVIDE_PROJECT_ID, |
| 35 | + GoogleBaseAsyncHook, |
| 36 | + GoogleBaseHook, |
| 37 | +) |
32 | 38 |
|
33 | 39 | Operation = Dict[str, Any] |
34 | 40 |
|
@@ -154,12 +160,14 @@ def _cdap_request( |
154 | 160 |
|
155 | 161 | @staticmethod |
156 | 162 | def _check_response_status_and_data(response, message: str) -> None: |
157 | | - if response.status != 200: |
| 163 | + if response.status == 404: |
| 164 | + raise AirflowNotFoundException(message) |
| 165 | + elif response.status != 200: |
158 | 166 | raise AirflowException(message) |
159 | 167 | if response.data is None: |
160 | 168 | raise AirflowException( |
161 | 169 | "Empty response received. Please, check for possible root " |
162 | | - "causes of this behavior either in DAG code or on Cloud Datafusion side" |
| 170 | + "causes of this behavior either in DAG code or on Cloud DataFusion side" |
163 | 171 | ) |
164 | 172 |
|
165 | 173 | def get_conn(self) -> Resource: |
@@ -418,7 +426,7 @@ def start_pipeline( |
418 | 426 | :param pipeline_name: Your pipeline name. |
419 | 427 | :param instance_url: Endpoint on which the REST APIs is accessible for the instance. |
420 | 428 | :param runtime_args: Optional runtime JSON args to be passed to the pipeline |
421 | | - :param namespace: f your pipeline belongs to a Basic edition instance, the namespace ID |
| 429 | + :param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID |
422 | 430 | is always default. If your pipeline belongs to an Enterprise edition instance, you |
423 | 431 | can create a namespace. |
424 | 432 | """ |
@@ -469,3 +477,88 @@ def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str = |
469 | 477 | self._check_response_status_and_data( |
470 | 478 | response, f"Stopping a pipeline failed with code {response.status}" |
471 | 479 | ) |
| 480 | + |
| 481 | + |
| 482 | +class DataFusionAsyncHook(GoogleBaseAsyncHook): |
| 483 | + """Class to get asynchronous hook for DataFusion""" |
| 484 | + |
| 485 | + sync_hook_class = DataFusionHook |
| 486 | + scopes = ["https://www.googleapis.com/auth/cloud-platform"] |
| 487 | + |
| 488 | + @staticmethod |
| 489 | + def _base_url(instance_url: str, namespace: str) -> str: |
| 490 | + return urljoin(f"{instance_url}/", f"v3/namespaces/{quote(namespace)}/apps/") |
| 491 | + |
| 492 | + async def _get_link(self, url: str, session): |
| 493 | + async with Token(scopes=self.scopes) as token: |
| 494 | + session_aio = AioSession(session) |
| 495 | + headers = { |
| 496 | + "Authorization": f"Bearer {await token.get()}", |
| 497 | + } |
| 498 | + try: |
| 499 | + pipeline = await session_aio.get(url=url, headers=headers) |
| 500 | + except AirflowException: |
| 501 | + pass # Because the pipeline may not be visible in system yet |
| 502 | + |
| 503 | + return pipeline |
| 504 | + |
| 505 | + async def get_pipeline( |
| 506 | + self, |
| 507 | + instance_url: str, |
| 508 | + namespace: str, |
| 509 | + pipeline_name: str, |
| 510 | + pipeline_id: str, |
| 511 | + session, |
| 512 | + ): |
| 513 | + base_url_link = self._base_url(instance_url, namespace) |
| 514 | + url = urljoin( |
| 515 | + base_url_link, f"{quote(pipeline_name)}/workflows/DataPipelineWorkflow/runs/{quote(pipeline_id)}" |
| 516 | + ) |
| 517 | + return await self._get_link(url=url, session=session) |
| 518 | + |
| 519 | + async def get_pipeline_status( |
| 520 | + self, |
| 521 | + pipeline_name: str, |
| 522 | + instance_url: str, |
| 523 | + pipeline_id: str, |
| 524 | + namespace: str = "default", |
| 525 | + success_states: list[str] | None = None, |
| 526 | + ) -> str: |
| 527 | + """ |
| 528 | + Gets a Cloud Data Fusion pipeline status asynchronously. |
| 529 | +
|
| 530 | + :param pipeline_name: Your pipeline name. |
| 531 | + :param instance_url: Endpoint on which the REST APIs is accessible for the instance. |
| 532 | + :param pipeline_id: Unique pipeline ID associated with specific pipeline |
| 533 | + :param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID |
| 534 | + is always default. If your pipeline belongs to an Enterprise edition instance, you |
| 535 | + can create a namespace. |
| 536 | + :param success_states: If provided the operator will wait for pipeline to be in one of |
| 537 | + the provided states. |
| 538 | + """ |
| 539 | + success_states = success_states or SUCCESS_STATES |
| 540 | + async with ClientSession() as session: |
| 541 | + try: |
| 542 | + pipeline = await self.get_pipeline( |
| 543 | + instance_url=instance_url, |
| 544 | + namespace=namespace, |
| 545 | + pipeline_name=pipeline_name, |
| 546 | + pipeline_id=pipeline_id, |
| 547 | + session=session, |
| 548 | + ) |
| 549 | + self.log.info("Response pipeline: %s", pipeline) |
| 550 | + pipeline = await pipeline.json(content_type=None) |
| 551 | + current_pipeline_state = pipeline["status"] |
| 552 | + |
| 553 | + if current_pipeline_state in success_states: |
| 554 | + pipeline_status = "success" |
| 555 | + elif current_pipeline_state in FAILURE_STATES: |
| 556 | + pipeline_status = "failed" |
| 557 | + else: |
| 558 | + pipeline_status = "pending" |
| 559 | + except OSError: |
| 560 | + pipeline_status = "pending" |
| 561 | + except Exception as e: |
| 562 | + self.log.info("Retrieving pipeline status finished with errors...") |
| 563 | + pipeline_status = str(e) |
| 564 | + return pipeline_status |
0 commit comments