|
23 | 23 | import pytest |
24 | 24 | from google.api_core.exceptions import AlreadyExists, NotFound |
25 | 25 | from google.api_core.retry import Retry |
26 | | -from google.cloud.dataproc_v1 import Batch, JobStatus |
| 26 | +from google.cloud import dataproc |
| 27 | +from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus |
27 | 28 |
|
28 | 29 | from airflow.exceptions import ( |
29 | 30 | AirflowException, |
@@ -579,7 +580,7 @@ def test_build_with_flex_migs(self): |
579 | 580 | assert CONFIG_WITH_FLEX_MIG == cluster |
580 | 581 |
|
581 | 582 |
|
582 | | -class TestDataprocClusterCreateOperator(DataprocClusterTestBase): |
| 583 | +class TestDataprocCreateClusterOperator(DataprocClusterTestBase): |
583 | 584 | def test_deprecation_warning(self): |
584 | 585 | with pytest.warns(AirflowProviderDeprecationWarning) as warnings: |
585 | 586 | op = DataprocCreateClusterOperator( |
@@ -883,6 +884,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook): |
883 | 884 | assert isinstance(exc.value.trigger, DataprocClusterTrigger) |
884 | 885 | assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME |
885 | 886 |
|
| 887 | + @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer")) |
| 888 | + @mock.patch(DATAPROC_PATH.format("DataprocHook")) |
| 889 | + @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook")) |
| 890 | + def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer): |
| 891 | + cluster = Cluster( |
| 892 | + cluster_name="test_cluster", |
| 893 | + status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING), |
| 894 | + ) |
| 895 | + mock_hook.return_value.create_cluster.return_value = cluster |
| 896 | + mock_hook.return_value.get_cluster.return_value = cluster |
| 897 | + operator = DataprocCreateClusterOperator( |
| 898 | + task_id=TASK_ID, |
| 899 | + region=GCP_REGION, |
| 900 | + project_id=GCP_PROJECT, |
| 901 | + cluster_config=CONFIG, |
| 902 | + labels=LABELS, |
| 903 | + cluster_name=CLUSTER_NAME, |
| 904 | + delete_on_error=True, |
| 905 | + metadata=METADATA, |
| 906 | + gcp_conn_id=GCP_CONN_ID, |
| 907 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 908 | + retry=RETRY, |
| 909 | + timeout=TIMEOUT, |
| 910 | + deferrable=True, |
| 911 | + ) |
| 912 | + |
| 913 | + operator.execute(mock.MagicMock()) |
| 914 | + assert not mock_defer.called |
| 915 | + |
| 916 | + mock_hook.assert_called_once_with( |
| 917 | + gcp_conn_id=GCP_CONN_ID, |
| 918 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 919 | + ) |
| 920 | + |
| 921 | + mock_hook.return_value.create_cluster.assert_called_once_with( |
| 922 | + region=GCP_REGION, |
| 923 | + project_id=GCP_PROJECT, |
| 924 | + cluster_config=CONFIG, |
| 925 | + request_id=None, |
| 926 | + labels=LABELS, |
| 927 | + cluster_name=CLUSTER_NAME, |
| 928 | + virtual_cluster_config=None, |
| 929 | + retry=RETRY, |
| 930 | + timeout=TIMEOUT, |
| 931 | + metadata=METADATA, |
| 932 | + ) |
| 933 | + mock_hook.return_value.wait_for_operation.assert_not_called() |
| 934 | + |
886 | 935 |
|
887 | 936 | @pytest.mark.db_test |
888 | 937 | @pytest.mark.need_serialized_dag |
@@ -1100,6 +1149,47 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook): |
1100 | 1149 | assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger) |
1101 | 1150 | assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME |
1102 | 1151 |
|
| 1152 | + @mock.patch(DATAPROC_PATH.format("DataprocDeleteClusterOperator.defer")) |
| 1153 | + @mock.patch(DATAPROC_PATH.format("DataprocHook")) |
| 1154 | + @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook")) |
| 1155 | + def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer): |
| 1156 | + mock_hook.return_value.create_cluster.return_value = None |
| 1157 | + mock_hook.return_value.get_cluster.side_effect = NotFound("test") |
| 1158 | + operator = DataprocDeleteClusterOperator( |
| 1159 | + task_id=TASK_ID, |
| 1160 | + region=GCP_REGION, |
| 1161 | + project_id=GCP_PROJECT, |
| 1162 | + cluster_name=CLUSTER_NAME, |
| 1163 | + request_id=REQUEST_ID, |
| 1164 | + gcp_conn_id=GCP_CONN_ID, |
| 1165 | + retry=RETRY, |
| 1166 | + timeout=TIMEOUT, |
| 1167 | + metadata=METADATA, |
| 1168 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 1169 | + deferrable=True, |
| 1170 | + ) |
| 1171 | + |
| 1172 | + operator.execute(mock.MagicMock()) |
| 1173 | + |
| 1174 | + mock_hook.assert_called_once_with( |
| 1175 | + gcp_conn_id=GCP_CONN_ID, |
| 1176 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 1177 | + ) |
| 1178 | + |
| 1179 | + mock_hook.return_value.delete_cluster.assert_called_once_with( |
| 1180 | + project_id=GCP_PROJECT, |
| 1181 | + region=GCP_REGION, |
| 1182 | + cluster_name=CLUSTER_NAME, |
| 1183 | + cluster_uuid=None, |
| 1184 | + request_id=REQUEST_ID, |
| 1185 | + retry=RETRY, |
| 1186 | + timeout=TIMEOUT, |
| 1187 | + metadata=METADATA, |
| 1188 | + ) |
| 1189 | + |
| 1190 | + mock_hook.return_value.wait_for_operation.assert_not_called() |
| 1191 | + assert not mock_defer.called |
| 1192 | + |
1103 | 1193 |
|
1104 | 1194 | class TestDataprocSubmitJobOperator(DataprocJobTestBase): |
1105 | 1195 | @mock.patch(DATAPROC_PATH.format("DataprocHook")) |
@@ -1240,8 +1330,8 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook): |
1240 | 1330 | assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME |
1241 | 1331 |
|
1242 | 1332 | @mock.patch(DATAPROC_PATH.format("DataprocHook")) |
1243 | | - @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer") |
1244 | | - @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job") |
| 1333 | + @mock.patch(DATAPROC_PATH.format("DataprocSubmitJobOperator.defer")) |
| 1334 | + @mock.patch(DATAPROC_PATH.format("DataprocHook.submit_job")) |
1245 | 1335 | def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job, mock_defer, mock_hook): |
1246 | 1336 | mock_submit_job.return_value.reference.job_id = TEST_JOB_ID |
1247 | 1337 | job_status = mock_hook.return_value.get_job.return_value.status |
@@ -1498,6 +1588,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook): |
1498 | 1588 | assert isinstance(exc.value.trigger, DataprocClusterTrigger) |
1499 | 1589 | assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME |
1500 | 1590 |
|
| 1591 | + @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer")) |
| 1592 | + @mock.patch(DATAPROC_PATH.format("DataprocHook")) |
| 1593 | + @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook")) |
| 1594 | + def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer): |
| 1595 | + cluster = Cluster( |
| 1596 | + cluster_name="test_cluster", |
| 1597 | + status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING), |
| 1598 | + ) |
| 1599 | + mock_hook.return_value.update_cluster.return_value = cluster |
| 1600 | + mock_hook.return_value.get_cluster.return_value = cluster |
| 1601 | + operator = DataprocUpdateClusterOperator( |
| 1602 | + task_id=TASK_ID, |
| 1603 | + region=GCP_REGION, |
| 1604 | + cluster_name=CLUSTER_NAME, |
| 1605 | + cluster=CLUSTER, |
| 1606 | + update_mask=UPDATE_MASK, |
| 1607 | + request_id=REQUEST_ID, |
| 1608 | + graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, |
| 1609 | + project_id=GCP_PROJECT, |
| 1610 | + gcp_conn_id=GCP_CONN_ID, |
| 1611 | + retry=RETRY, |
| 1612 | + timeout=TIMEOUT, |
| 1613 | + metadata=METADATA, |
| 1614 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 1615 | + deferrable=True, |
| 1616 | + ) |
| 1617 | + |
| 1618 | + operator.execute(mock.MagicMock()) |
| 1619 | + |
| 1620 | + mock_hook.assert_called_once_with( |
| 1621 | + gcp_conn_id=GCP_CONN_ID, |
| 1622 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 1623 | + ) |
| 1624 | + mock_hook.return_value.update_cluster.assert_called_once_with( |
| 1625 | + project_id=GCP_PROJECT, |
| 1626 | + region=GCP_REGION, |
| 1627 | + cluster_name=CLUSTER_NAME, |
| 1628 | + cluster=CLUSTER, |
| 1629 | + update_mask=UPDATE_MASK, |
| 1630 | + request_id=REQUEST_ID, |
| 1631 | + graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, |
| 1632 | + retry=RETRY, |
| 1633 | + timeout=TIMEOUT, |
| 1634 | + metadata=METADATA, |
| 1635 | + ) |
| 1636 | + mock_hook.return_value.wait_for_operation.assert_not_called() |
| 1637 | + assert not mock_defer.called |
| 1638 | + |
1501 | 1639 |
|
1502 | 1640 | @pytest.mark.db_test |
1503 | 1641 | @pytest.mark.need_serialized_dag |
|
0 commit comments