|
16 | 16 | # specific language governing permissions and limitations |
17 | 17 | # under the License. |
18 | 18 | """This module contains a Google Cloud Dataflow sensor.""" |
19 | | -from typing import Optional, Sequence, Set, Union |
| 19 | +from typing import Callable, Optional, Sequence, Set, Union |
20 | 20 |
|
21 | 21 | from airflow.exceptions import AirflowException |
22 | 22 | from airflow.providers.google.cloud.hooks.dataflow import ( |
@@ -116,3 +116,75 @@ def poke(self, context: dict) -> bool: |
116 | 116 | raise AirflowException(f"Job with id '{self.job_id}' is already in terminal state: {job_status}") |
117 | 117 |
|
118 | 118 | return False |
| 119 | + |
| 120 | + |
| 121 | +class DataflowJobMetricsSensor(BaseSensorOperator): |
| 122 | + """ |
| 123 | + Checks the metrics of a job in Google Cloud Dataflow. |
| 124 | +
|
| 125 | + :param job_id: ID of the job to be checked. |
| 126 | + :type job_id: str |
| 127 | + :param callback: callback which is called with list of read job metrics |
| 128 | + See: |
| 129 | + https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate |
| 130 | + :type callback: callable |
| 131 | + :param project_id: Optional, the Google Cloud project ID in which to start a job. |
| 132 | + If set to None or missing, the default project_id from the Google Cloud connection is used. |
| 133 | + :type project_id: str |
| 134 | + :param location: The location of the Dataflow job (for example europe-west1). See: |
| 135 | + https://cloud.google.com/dataflow/docs/concepts/regional-endpoints |
| 136 | + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. |
| 137 | + :type gcp_conn_id: str |
| 138 | + :param delegate_to: The account to impersonate using domain-wide delegation of authority, |
| 139 | + if any. For this to work, the service account making the request must have |
| 140 | + domain-wide delegation enabled. |
| 141 | + :type delegate_to: str |
| 142 | + :param impersonation_chain: Optional service account to impersonate using short-term |
| 143 | + credentials, or chained list of accounts required to get the access_token |
| 144 | + of the last account in the list, which will be impersonated in the request. |
| 145 | + If set as a string, the account must grant the originating account |
| 146 | + the Service Account Token Creator IAM role. |
| 147 | + If set as a sequence, the identities from the list must grant |
| 148 | + Service Account Token Creator IAM role to the directly preceding identity, with first |
| 149 | + account from the list granting this role to the originating account (templated). |
| 150 | + :type impersonation_chain: Union[str, Sequence[str]] |
| 151 | + """ |
| 152 | + |
| 153 | + template_fields = ['job_id'] |
| 154 | + |
| 155 | + @apply_defaults |
| 156 | + def __init__( |
| 157 | + self, |
| 158 | + *, |
| 159 | + job_id: str, |
| 160 | + callback: Callable[[dict], bool], |
| 161 | + project_id: Optional[str] = None, |
| 162 | + location: str = DEFAULT_DATAFLOW_LOCATION, |
| 163 | + gcp_conn_id: str = 'google_cloud_default', |
| 164 | + delegate_to: Optional[str] = None, |
| 165 | + impersonation_chain: Optional[Union[str, Sequence[str]]] = None, |
| 166 | + **kwargs, |
| 167 | + ) -> None: |
| 168 | + super().__init__(**kwargs) |
| 169 | + self.job_id = job_id |
| 170 | + self.project_id = project_id |
| 171 | + self.callback = callback |
| 172 | + self.location = location |
| 173 | + self.gcp_conn_id = gcp_conn_id |
| 174 | + self.delegate_to = delegate_to |
| 175 | + self.impersonation_chain = impersonation_chain |
| 176 | + self.hook: Optional[DataflowHook] = None |
| 177 | + |
| 178 | + def poke(self, context: dict) -> bool: |
| 179 | + self.hook = DataflowHook( |
| 180 | + gcp_conn_id=self.gcp_conn_id, |
| 181 | + delegate_to=self.delegate_to, |
| 182 | + impersonation_chain=self.impersonation_chain, |
| 183 | + ) |
| 184 | + result = self.hook.fetch_job_metrics_by_id( |
| 185 | + job_id=self.job_id, |
| 186 | + project_id=self.project_id, |
| 187 | + location=self.location, |
| 188 | + ) |
| 189 | + |
| 190 | + return self.callback(result["metrics"]) |
0 commit comments