1717from __future__ import annotations
1818
1919from unittest import mock
20+ from unittest .mock import MagicMock
2021
2122from google .api_core .gapic_v1 .method import DEFAULT
2223from google .api_core .retry import Retry
@@ -783,7 +784,12 @@ class TestVertexAICreateAutoMLTabularTrainingJobOperator:
783784 @mock .patch ("google.cloud.aiplatform.datasets.TabularDataset" )
784785 @mock .patch (VERTEX_AI_PATH .format ("auto_ml.AutoMLHook" ))
785786 def test_execute (self , mock_hook , mock_dataset ):
786- mock_hook .return_value .create_auto_ml_tabular_training_job .return_value = (None , "training_id" )
787+ mock_hook .return_value = MagicMock (
788+ ** {
789+ "create_auto_ml_tabular_training_job.return_value" : (None , "training_id" ),
790+ "get_credentials_and_project_id.return_value" : ("creds" , "project_id" ),
791+ }
792+ )
787793 op = CreateAutoMLTabularTrainingJobOperator (
788794 task_id = TASK_ID ,
789795 gcp_conn_id = GCP_CONN_ID ,
@@ -798,7 +804,9 @@ def test_execute(self, mock_hook, mock_dataset):
798804 )
799805 op .execute (context = {"ti" : mock .MagicMock ()})
800806 mock_hook .assert_called_once_with (gcp_conn_id = GCP_CONN_ID , impersonation_chain = IMPERSONATION_CHAIN )
801- mock_dataset .assert_called_once_with (dataset_name = TEST_DATASET_ID )
807+ mock_dataset .assert_called_once_with (
808+ dataset_name = TEST_DATASET_ID , project = GCP_PROJECT , credentials = "creds"
809+ )
802810 mock_hook .return_value .create_auto_ml_tabular_training_job .assert_called_once_with (
803811 project_id = GCP_PROJECT ,
804812 region = GCP_LOCATION ,
0 commit comments