Skip to content

Commit 80a957f

Browse files
author
Tobiasz Kędzierski
authored
Add Dataflow sensors - job metrics (#12039)
1 parent ae7cb4a commit 80a957f

File tree

6 files changed

+213
-4
lines changed

6 files changed

+213
-4
lines changed

airflow/providers/google/cloud/example_dags/example_dataflow.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,19 @@
2020
Example Airflow DAG for Google Cloud Dataflow service
2121
"""
2222
import os
23+
from typing import Callable, Dict, List
2324
from urllib.parse import urlparse
2425

2526
from airflow import models
27+
from airflow.exceptions import AirflowException
2628
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
2729
from airflow.providers.google.cloud.operators.dataflow import (
2830
CheckJobRunning,
2931
DataflowCreateJavaJobOperator,
3032
DataflowCreatePythonJobOperator,
3133
DataflowTemplatedJobStartOperator,
3234
)
33-
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
35+
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobMetricsSensor, DataflowJobStatusSensor
3436
from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
3537
from airflow.utils.dates import days_ago
3638

@@ -159,7 +161,30 @@
159161
location='europe-west3',
160162
)
161163

164+
def check_metric_scalar_gte(metric_name: str, value: int) -> Callable:
    """Build a callback that checks whether a named scalar metric is >= ``value``.

    :param metric_name: name of the Dataflow metric to look for; matched against
        ``name.context.original_name`` of each MetricUpdate entry.
    :param value: minimum scalar value the metric must reach for the check to pass.
    :return: a callback suitable for ``DataflowJobMetricsSensor``; it returns True/False
        for the matching metric and raises ``AirflowException`` when the metric is absent.
    """

    def callback(metrics: List[Dict]) -> bool:
        dag_native_python_async.log.info("Looking for '%s' >= %d", metric_name, value)
        for metric in metrics:
            context = metric.get("name", {}).get("context", {})
            original_name = context.get("original_name", "")
            tentative = context.get("tentative", "")
            # BUG FIX: compare against the requested metric_name instead of the
            # hard-coded "Service-cpu_num_seconds" — otherwise the metric_name
            # parameter was silently ignored for every caller.
            if original_name == metric_name and not tentative:
                return metric["scalar"] >= value
        raise AirflowException(f"Metric '{metric_name}' not found in metrics")

    return callback
179+
# Sensor task: poll the Dataflow job metrics until the service-reported
# CPU-seconds metric reaches 100.  The job id is pulled from the XCom that
# the 'start-python-job-async' task pushed at submission time (templated).
wait_for_python_job_async_metric = DataflowJobMetricsSensor(
    task_id="wait-for-python-job-async-metric",
    job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}",
    location='europe-west3',
    callback=check_metric_scalar_gte(metric_name="Service-cpu_num_seconds", value=100),
)

# Both sensors depend only on the async job having been started; they poll
# independently of each other.
start_python_job_async >> wait_for_python_job_async_done
start_python_job_async >> wait_for_python_job_async_metric
163188

164189

165190
with models.DAG(

airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,27 @@ def fetch_job_by_id(self, job_id: str) -> dict:
243243
.execute(num_retries=self._num_retries)
244244
)
245245

246+
def fetch_job_metrics_by_id(self, job_id: str) -> dict:
    """
    Helper method to fetch the job metrics with the specified Job ID.

    :param job_id: Job ID to get.
    :type job_id: str
    :return: the JobMetrics. See:
        https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics
    :rtype: dict
    """
    # Build the request in named steps rather than one long call chain.
    jobs_resource = self._dataflow.projects().locations().jobs()
    request = jobs_resource.getMetrics(
        projectId=self._project_number,
        location=self._job_location,
        jobId=job_id,
    )
    result = request.execute(num_retries=self._num_retries)

    self.log.debug("fetch_job_metrics_by_id %s:\n%s", job_id, result)
    return result
266+
246267
def _fetch_all_jobs(self) -> List[dict]:
247268
request = (
248269
self._dataflow.projects()
@@ -1101,3 +1122,31 @@ def get_job(
11011122
location=location,
11021123
)
11031124
return jobs_controller.fetch_job_by_id(job_id)
1125+
1126+
@GoogleBaseHook.fallback_to_default_project_id
def fetch_job_metrics_by_id(
    self,
    job_id: str,
    project_id: str,
    location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict:
    """
    Gets the job metrics with the specified Job ID.

    :param job_id: Job ID to get.
    :type job_id: str
    :param project_id: Optional, the Google Cloud project ID in which to start a job.
        If set to None or missing, the default project_id from the Google Cloud connection is used.
    :type project_id: str
    :param location: The location of the Dataflow job (for example europe-west1). See:
        https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
    :type location: str
    :return: the JobMetrics. See:
        https://cloud.google.com/dataflow/docs/reference/rest/v1b3/JobMetrics
    :rtype: dict
    """
    # Delegate to a jobs controller bound to the resolved project and location;
    # the controller owns the actual REST call.
    jobs_controller = _DataflowJobsController(
        dataflow=self.get_conn(),
        project_number=project_id,
        location=location,
    )
    return jobs_controller.fetch_job_metrics_by_id(job_id)

airflow/providers/google/cloud/operators/dataflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@ def __init__( # pylint: disable=too-many-arguments
909909
self.cancel_timeout = cancel_timeout
910910
self.wait_until_finished = wait_until_finished
911911
self.job_id = None
912-
self.hook = None
912+
self.hook: Optional[DataflowHook] = None
913913

914914
def execute(self, context):
915915
"""Execute the python dataflow job."""

airflow/providers/google/cloud/sensors/dataflow.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# specific language governing permissions and limitations
1717
# under the License.
1818
"""This module contains a Google Cloud Dataflow sensor."""
19-
from typing import Optional, Sequence, Set, Union
19+
from typing import Callable, Optional, Sequence, Set, Union
2020

2121
from airflow.exceptions import AirflowException
2222
from airflow.providers.google.cloud.hooks.dataflow import (
@@ -116,3 +116,75 @@ def poke(self, context: dict) -> bool:
116116
raise AirflowException(f"Job with id '{self.job_id}' is already in terminal state: {job_status}")
117117

118118
return False
119+
120+
121+
class DataflowJobMetricsSensor(BaseSensorOperator):
    """
    Checks the metrics of a job in Google Cloud Dataflow.

    On every poke the sensor fetches the current JobMetrics for ``job_id`` and
    hands the list of MetricUpdate entries to ``callback``; the sensor succeeds
    once the callback returns a truthy value.

    :param job_id: ID of the job to be checked.
    :type job_id: str
    :param callback: callback which is called with list of read job metrics
        See:
        https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate
    :type callback: callable
    :param project_id: Optional, the Google Cloud project ID in which to start a job.
        If set to None or missing, the default project_id from the Google Cloud connection is used.
    :type project_id: str
    :param location: The location of the Dataflow job (for example europe-west1). See:
        https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
        if any. For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: str
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :type impersonation_chain: Union[str, Sequence[str]]
    """

    template_fields = ['job_id']

    @apply_defaults
    def __init__(
        self,
        *,
        job_id: str,
        callback: Callable[[dict], bool],
        project_id: Optional[str] = None,
        location: str = DEFAULT_DATAFLOW_LOCATION,
        gcp_conn_id: str = 'google_cloud_default',
        delegate_to: Optional[str] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Job identity and the user-supplied evaluation hook.
        self.job_id = job_id
        self.callback = callback
        # Where to look for the job.
        self.project_id = project_id
        self.location = location
        # Connection / credentials configuration.
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.impersonation_chain = impersonation_chain
        # Created lazily on first poke.
        self.hook: Optional[DataflowHook] = None

    def poke(self, context: dict) -> bool:
        """Fetch the job metrics and return the result of the user callback."""
        hook = DataflowHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        self.hook = hook

        job_metrics = hook.fetch_job_metrics_by_id(
            job_id=self.job_id,
            project_id=self.project_id,
            location=self.location,
        )
        return self.callback(job_metrics["metrics"])

tests/providers/google/cloud/hooks/test_dataflow.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,37 @@ def test_get_job(self, mock_conn, mock_dataflowjob):
648648
)
649649
method_fetch_job_by_id.assert_called_once_with(TEST_JOB_ID)
650650

651+
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
def test_fetch_job_metrics_by_id(self, mock_conn, mock_dataflowjob):
    # The hook must build a controller from the live connection plus the
    # given project/location, then delegate the metric fetch to it.
    method_fetch_job_metrics_by_id = mock_dataflowjob.return_value.fetch_job_metrics_by_id

    self.dataflow_hook.fetch_job_metrics_by_id(
        job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
    )
    mock_conn.assert_called_once()
    # Controller is constructed exactly once with the resolved parameters.
    mock_dataflowjob.assert_called_once_with(
        dataflow=mock_conn.return_value,
        project_number=TEST_PROJECT_ID,
        location=TEST_LOCATION,
    )
    # Only the job id is forwarded to the controller-level method.
    method_fetch_job_metrics_by_id.assert_called_once_with(TEST_JOB_ID)
666+
667+
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
def test_fetch_job_metrics_by_id_controller(self, mock_conn):
    # End-to-end through the real controller: verify the REST call shape
    # (projects().locations().jobs().getMetrics(...).execute(...)).
    method_get_metrics = (
        mock_conn.return_value.projects.return_value.locations.return_value.jobs.return_value.getMetrics
    )
    self.dataflow_hook.fetch_job_metrics_by_id(
        job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
    )

    mock_conn.assert_called_once()
    # execute() must carry the hook's retry setting (0 in tests).
    method_get_metrics.return_value.execute.assert_called_once_with(num_retries=0)
    # getMetrics is keyed by job, project, and regional endpoint.
    method_get_metrics.assert_called_once_with(
        jobId=TEST_JOB_ID, projectId=TEST_PROJECT_ID, location=TEST_LOCATION
    )
681+
651682

652683
class TestDataflowTemplateHook(unittest.TestCase):
653684
def setUp(self):

tests/providers/google/cloud/sensors/test_dataflow.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from airflow.exceptions import AirflowException
2525
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
26-
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
26+
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobMetricsSensor, DataflowJobStatusSensor
2727

2828
TEST_TASK_ID = "tesk-id"
2929
TEST_JOB_ID = "test_job_id"
@@ -98,3 +98,35 @@ def test_poke_raise_exception(self, mock_hook):
9898
mock_get_job.assert_called_once_with(
9999
job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
100100
)
101+
102+
103+
class TestDataflowJobMetricsSensor(unittest.TestCase):
    """Unit tests for DataflowJobMetricsSensor.poke()."""

    @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook")
    def test_poke(self, mock_hook):
        # poke() should: build the hook from the connection settings, fetch the
        # metrics for the (templated) job id, and return whatever the user
        # callback yields for result["metrics"].
        mock_fetch_job_metrics_by_id = mock_hook.return_value.fetch_job_metrics_by_id
        callback = mock.MagicMock()

        task = DataflowJobMetricsSensor(
            task_id=TEST_TASK_ID,
            job_id=TEST_JOB_ID,
            callback=callback,
            location=TEST_LOCATION,
            project_id=TEST_PROJECT_ID,
            gcp_conn_id=TEST_GCP_CONN_ID,
            delegate_to=TEST_DELEGATE_TO,
            impersonation_chain=TEST_IMPERSONATION_CHAIN,
        )
        results = task.poke(mock.MagicMock())

        # The sensor's poke result is the callback's return value, unmodified.
        self.assertEqual(callback.return_value, results)

        # Hook constructed with exactly the credentials configuration.
        mock_hook.assert_called_once_with(
            gcp_conn_id=TEST_GCP_CONN_ID,
            delegate_to=TEST_DELEGATE_TO,
            impersonation_chain=TEST_IMPERSONATION_CHAIN,
        )
        mock_fetch_job_metrics_by_id.assert_called_once_with(
            job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION
        )
        # Only the "metrics" key of the JobMetrics response is passed on.
        mock_fetch_job_metrics_by_id.return_value.__getitem__.assert_called_once_with("metrics")
        callback.assert_called_once_with(mock_fetch_job_metrics_by_id.return_value.__getitem__.return_value)

0 commit comments

Comments
 (0)