Skip to content

Commit aeb7e90

Browse files
olegkachur-e and Oleg Kachur authored
Fix TestTranslationLegacyModelPredictLink dataset_id error (#42463)
- Add dataset_id parameter to let TestTranslationLegacyModelPredictLink work with the translation model.

Co-authored-by: Oleg Kachur <kachur@google.com>
1 parent e286bd7 commit aeb7e90

File tree

4 files changed

+44
-20
lines changed

4 files changed

+44
-20
lines changed

providers/src/airflow/providers/google/cloud/links/translate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,14 @@ def persist(
167167
task_instance,
168168
model_id: str,
169169
project_id: str,
170+
dataset_id: str,
170171
):
171172
task_instance.xcom_push(
172173
context,
173174
key=TranslationLegacyModelPredictLink.key,
174175
value={
175176
"location": task_instance.location,
176-
"dataset_id": task_instance.model.dataset_id,
177+
"dataset_id": dataset_id,
177178
"model_id": model_id,
178179
"project_id": project_id,
179180
},

providers/src/airflow/providers/google/cloud/operators/automl.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import ast
2323
import warnings
2424
from functools import cached_property
25-
from typing import TYPE_CHECKING, Sequence, Tuple
25+
from typing import TYPE_CHECKING, Sequence, Tuple, cast
2626

2727
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
2828
from google.cloud.automl_v1beta1 import (
@@ -280,17 +280,22 @@ def hook(self) -> CloudAutoMLHook | PredictionServiceHook:
280280
impersonation_chain=self.impersonation_chain,
281281
)
282282

283+
@cached_property
284+
def model(self) -> Model | None:
285+
if self.model_id:
286+
hook = cast(CloudAutoMLHook, self.hook)
287+
return hook.get_model(
288+
model_id=self.model_id,
289+
location=self.location,
290+
project_id=self.project_id,
291+
retry=self.retry,
292+
timeout=self.timeout,
293+
metadata=self.metadata,
294+
)
295+
return None
296+
283297
def _check_model_type(self):
284-
hook = self.hook
285-
model = hook.get_model(
286-
model_id=self.model_id,
287-
location=self.location,
288-
project_id=self.project_id,
289-
retry=self.retry,
290-
timeout=self.timeout,
291-
metadata=self.metadata,
292-
)
293-
if not hasattr(model, "translation_model_metadata"):
298+
if not hasattr(self.model, "translation_model_metadata"):
294299
raise AirflowException(
295300
"AutoMLPredictOperator for text, image, and video prediction has been deprecated. "
296301
"Please use endpoint_id param instead of model_id param."
@@ -329,11 +334,13 @@ def execute(self, context: Context):
329334
)
330335

331336
project_id = self.project_id or hook.project_id
332-
if project_id and self.model_id:
337+
dataset_id: str | None = self.model.dataset_id if self.model else None
338+
if project_id and self.model_id and dataset_id:
333339
TranslationLegacyModelPredictLink.persist(
334340
context=context,
335341
task_instance=self,
336342
model_id=self.model_id,
343+
dataset_id=dataset_id,
337344
project_id=project_id,
338345
)
339346
return PredictResponse.to_dict(result)
@@ -431,12 +438,16 @@ def __init__(
431438
self.input_config = input_config
432439
self.output_config = output_config
433440

434-
def execute(self, context: Context):
435-
hook = CloudAutoMLHook(
441+
@cached_property
442+
def hook(self) -> CloudAutoMLHook:
443+
return CloudAutoMLHook(
436444
gcp_conn_id=self.gcp_conn_id,
437445
impersonation_chain=self.impersonation_chain,
438446
)
439-
self.model: Model = hook.get_model(
447+
448+
@cached_property
449+
def model(self) -> Model:
450+
return self.hook.get_model(
440451
model_id=self.model_id,
441452
location=self.location,
442453
project_id=self.project_id,
@@ -445,6 +456,7 @@ def execute(self, context: Context):
445456
metadata=self.metadata,
446457
)
447458

459+
def execute(self, context: Context):
448460
if not hasattr(self.model, "translation_model_metadata"):
449461
_raise_exception_for_deprecated_operator(
450462
self.__class__.__name__,
@@ -456,7 +468,7 @@ def execute(self, context: Context):
456468
],
457469
)
458470
self.log.info("Fetch batch prediction.")
459-
operation = hook.batch_predict(
471+
operation = self.hook.batch_predict(
460472
model_id=self.model_id,
461473
input_config=self.input_config,
462474
output_config=self.output_config,
@@ -467,16 +479,17 @@ def execute(self, context: Context):
467479
timeout=self.timeout,
468480
metadata=self.metadata,
469481
)
470-
operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
482+
operation_result = self.hook.wait_for_operation(timeout=self.timeout, operation=operation)
471483
result = BatchPredictResult.to_dict(operation_result)
472484
self.log.info("Batch prediction is ready.")
473-
project_id = self.project_id or hook.project_id
485+
project_id = self.project_id or self.hook.project_id
474486
if project_id:
475487
TranslationLegacyModelPredictLink.persist(
476488
context=context,
477489
task_instance=self,
478490
model_id=self.model_id,
479491
project_id=project_id,
492+
dataset_id=self.model.dataset_id,
480493
)
481494
return result
482495

providers/tests/google/cloud/links/test_translate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,12 @@ def test_get_link(self, create_task_instance_of_operator, session):
161161
ti.task.model = Model(dataset_id=DATASET, display_name=MODEL)
162162
session.add(ti)
163163
session.commit()
164-
link.persist(context={"ti": ti}, task_instance=ti.task, model_id=MODEL, project_id=GCP_PROJECT_ID)
164+
link.persist(
165+
context={"ti": ti},
166+
task_instance=ti.task,
167+
model_id=MODEL,
168+
project_id=GCP_PROJECT_ID,
169+
dataset_id=DATASET,
170+
)
165171
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
166172
assert actual_url == expected_url

providers/tests/google/cloud/operators/test_automl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def test_execute(self, mock_hook, mock_link_persist):
147147
mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
148148
mock_hook.return_value.extract_object_id = extract_object_id
149149
mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult()
150+
mock_hook.return_value.get_model.return_value = mock.MagicMock(**MODEL)
150151
mock_context = {"ti": mock.MagicMock()}
151152
with pytest.warns(AirflowProviderDeprecationWarning):
152153
op = AutoMLBatchPredictOperator(
@@ -175,6 +176,7 @@ def test_execute(self, mock_hook, mock_link_persist):
175176
task_instance=op,
176177
model_id=MODEL_ID,
177178
project_id=GCP_PROJECT_ID,
179+
dataset_id=DATASET_ID,
178180
)
179181

180182
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -243,6 +245,7 @@ class TestAutoMLPredictOperator:
243245
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
244246
def test_execute(self, mock_hook, mock_link_persist):
245247
mock_hook.return_value.predict.return_value = PredictResponse()
248+
mock_hook.return_value.get_model.return_value = mock.MagicMock(**MODEL)
246249
mock_context = {"ti": mock.MagicMock()}
247250
op = AutoMLPredictOperator(
248251
model_id=MODEL_ID,
@@ -268,6 +271,7 @@ def test_execute(self, mock_hook, mock_link_persist):
268271
task_instance=op,
269272
model_id=MODEL_ID,
270273
project_id=GCP_PROJECT_ID,
274+
dataset_id=DATASET_ID,
271275
)
272276

273277
@pytest.mark.db_test

0 commit comments

Comments
 (0)