
Commit e07a42e

Lee-W and pankajkoti authored
Check cluster state before deferring Dataproc operators to the trigger (#36892)
While operating a Dataproc cluster in deferrable mode, the condition may already be met (cluster created, deleted, or updated) before we defer the task to the trigger. This PR checks the cluster status before deferring the task to the trigger.

Co-authored-by: Pankaj Koti <pankajkoti699@gmail.com>
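For readers unfamiliar with deferrable operators, the general pattern this commit applies is sketched below with a hypothetical operator; `CheckThenDeferOperator`, `_condition_met`, and the use of `TimeDeltaTrigger` are illustrative assumptions, not code from this change:

from __future__ import annotations

from datetime import timedelta

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class CheckThenDeferOperator(BaseOperator):
    """Hypothetical operator sketching the check-before-defer pattern."""

    def __init__(self, *, wait_seconds: int = 300, **kwargs) -> None:
        super().__init__(**kwargs)
        self.wait_seconds = wait_seconds

    def _condition_met(self) -> bool:
        # Hypothetical check; the real operators call hook.get_cluster() here
        # and inspect cluster.status.state (or catch NotFound for deletion).
        return False

    def execute(self, context):
        if self._condition_met():
            # The awaited state is already reached: finish in the worker and
            # never hand the task over to the triggerer.
            self.log.info("Condition already met, skipping deferral.")
            return
        # Otherwise defer exactly as before.
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(seconds=self.wait_seconds)),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        self.log.info("Resumed after deferral.")

Finishing in the worker when the state is already terminal avoids a round trip through the triggerer and an extra task resume for work that is effectively done.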
1 parent d48985c commit e07a42e

File tree

2 files changed: +185 -24 lines


airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 43 additions & 20 deletions
@@ -721,6 +721,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
     def execute(self, context: Context) -> dict:
         self.log.info("Creating cluster: %s", self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+
         # Save data required to display extra link no matter what the cluster status will be
         project_id = self.project_id or hook.project_id
         if project_id:
@@ -731,6 +732,7 @@ def execute(self, context: Context) -> dict:
                 project_id=project_id,
                 region=self.region,
             )
+
         try:
             # First try to create a new cluster
             operation = self._create_cluster(hook)
@@ -741,17 +743,24 @@ def execute(self, context: Context) -> dict:
                 self.log.info("Cluster created.")
                 return Cluster.to_dict(cluster)
             else:
-                self.defer(
-                    trigger=DataprocClusterTrigger(
-                        cluster_name=self.cluster_name,
-                        project_id=self.project_id,
-                        region=self.region,
-                        gcp_conn_id=self.gcp_conn_id,
-                        impersonation_chain=self.impersonation_chain,
-                        polling_interval_seconds=self.polling_interval_seconds,
-                    ),
-                    method_name="execute_complete",
+                cluster = hook.get_cluster(
+                    project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
                 )
+                if cluster.status.state == cluster.status.State.RUNNING:
+                    self.log.info("Cluster created.")
+                    return Cluster.to_dict(cluster)
+                else:
+                    self.defer(
+                        trigger=DataprocClusterTrigger(
+                            cluster_name=self.cluster_name,
+                            project_id=self.project_id,
+                            region=self.region,
+                            gcp_conn_id=self.gcp_conn_id,
+                            impersonation_chain=self.impersonation_chain,
+                            polling_interval_seconds=self.polling_interval_seconds,
+                        ),
+                        method_name="execute_complete",
+                    )
         except AlreadyExists:
             if not self.use_if_exists:
                 raise
@@ -1016,6 +1025,16 @@ def execute(self, context: Context) -> None:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
             self.log.info("Cluster deleted.")
         else:
+            try:
+                hook.get_cluster(
+                    project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+                )
+            except NotFound:
+                self.log.info("Cluster deleted.")
+                return
+            except Exception as e:
+                raise AirflowException(str(e))
+
             end_time: float = time.time() + self.timeout
             self.defer(
                 trigger=DataprocDeleteClusterTrigger(
@@ -2480,17 +2499,21 @@ def execute(self, context: Context):
         if not self.deferrable:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
         else:
-            self.defer(
-                trigger=DataprocClusterTrigger(
-                    cluster_name=self.cluster_name,
-                    project_id=self.project_id,
-                    region=self.region,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    polling_interval_seconds=self.polling_interval_seconds,
-                ),
-                method_name="execute_complete",
+            cluster = hook.get_cluster(
+                project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
             )
+            if cluster.status.state != cluster.status.State.RUNNING:
+                self.defer(
+                    trigger=DataprocClusterTrigger(
+                        cluster_name=self.cluster_name,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
         self.log.info("Updated %s cluster.", self.cluster_name)
 
     def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
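For context, a minimal DAG sketch of how the affected operators are typically used with deferrable=True; the project, region, and cluster values are placeholders rather than anything from this commit:

from __future__ import annotations

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocCreateClusterOperator,
    DataprocDeleteClusterOperator,
)

with DAG(dag_id="example_dataproc_deferrable", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        cluster_name="my-cluster",  # placeholder
        cluster_config={},  # placeholder; supply a real cluster config dict
        deferrable=True,  # the code path changed by this commit
    )
    delete_cluster = DataprocDeleteClusterOperator(
        task_id="delete_cluster",
        project_id="my-project",
        region="us-central1",
        cluster_name="my-cluster",
        deferrable=True,  # now returns early if the cluster is already gone
    )
    create_cluster >> delete_cluster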

tests/providers/google/cloud/operators/test_dataproc.py

Lines changed: 142 additions & 4 deletions
@@ -23,7 +23,8 @@
 import pytest
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
-from google.cloud.dataproc_v1 import Batch, JobStatus
+from google.cloud import dataproc
+from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
 
 from airflow.exceptions import (
     AirflowException,
@@ -579,7 +580,7 @@ def test_build_with_flex_migs(self):
         assert CONFIG_WITH_FLEX_MIG == cluster
 
 
-class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
+class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):
         with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
             op = DataprocCreateClusterOperator(
@@ -883,6 +884,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.create_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            delete_on_error=True,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+        assert not mock_defer.called
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.create_cluster.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            request_id=None,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            virtual_cluster_config=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
@@ -1100,6 +1149,47 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
         assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocDeleteClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        mock_hook.return_value.create_cluster.return_value = None
+        mock_hook.return_value.get_cluster.side_effect = NotFound("test")
+        operator = DataprocDeleteClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            request_id=REQUEST_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.delete_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster_uuid=None,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+
 
 class TestDataprocSubmitJobOperator(DataprocJobTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1240,8 +1330,8 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
-    @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job")
+    @mock.patch(DATAPROC_PATH.format("DataprocSubmitJobOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook.submit_job"))
     def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job, mock_defer, mock_hook):
         mock_submit_job.return_value.reference.job_id = TEST_JOB_ID
         job_status = mock_hook.return_value.get_job.return_value.status
@@ -1498,6 +1588,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.update_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocUpdateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            project_id=GCP_PROJECT,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.update_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
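The new tests all follow the same recipe: patch the operator's defer method and the hook, simulate a terminal state, then assert that defer was never called. A stripped-down, self-contained sketch of that technique with a hypothetical operator (not the Dataproc classes):

from __future__ import annotations

from unittest import mock

from airflow.models.baseoperator import BaseOperator


class AlreadyDoneOperator(BaseOperator):
    """Hypothetical operator that only defers when work is still pending."""

    def execute(self, context):
        if self._is_done():  # pretend we polled the service once
            return "done"
        self.defer(trigger=None, method_name="execute_complete")  # simplified

    def _is_done(self) -> bool:
        return False


def test_execute_finishes_before_defer():
    operator = AlreadyDoneOperator(task_id="check")
    # Patch both the state check and defer, mirroring the mock.patch decorators above.
    with mock.patch.object(AlreadyDoneOperator, "_is_done", return_value=True), mock.patch.object(
        AlreadyDoneOperator, "defer"
    ) as mock_defer:
        assert operator.execute(context={}) == "done"
        mock_defer.assert_not_called()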
