2222import ast
2323import warnings
2424from functools import cached_property
25- from typing import TYPE_CHECKING , Sequence , Tuple
25+ from typing import TYPE_CHECKING , Sequence , Tuple , cast
2626
2727from google .api_core .gapic_v1 .method import DEFAULT , _MethodDefault
2828from google .cloud .automl_v1beta1 import (
@@ -280,17 +280,22 @@ def hook(self) -> CloudAutoMLHook | PredictionServiceHook:
280280 impersonation_chain = self .impersonation_chain ,
281281 )
282282
283+ @cached_property
284+ def model (self ) -> Model | None :
285+ if self .model_id :
286+ hook = cast (CloudAutoMLHook , self .hook )
287+ return hook .get_model (
288+ model_id = self .model_id ,
289+ location = self .location ,
290+ project_id = self .project_id ,
291+ retry = self .retry ,
292+ timeout = self .timeout ,
293+ metadata = self .metadata ,
294+ )
295+ return None
296+
283297 def _check_model_type (self ):
284- hook = self .hook
285- model = hook .get_model (
286- model_id = self .model_id ,
287- location = self .location ,
288- project_id = self .project_id ,
289- retry = self .retry ,
290- timeout = self .timeout ,
291- metadata = self .metadata ,
292- )
293- if not hasattr (model , "translation_model_metadata" ):
298+ if not hasattr (self .model , "translation_model_metadata" ):
294299 raise AirflowException (
295300 "AutoMLPredictOperator for text, image, and video prediction has been deprecated. "
296301 "Please use endpoint_id param instead of model_id param."
@@ -329,11 +334,13 @@ def execute(self, context: Context):
329334 )
330335
331336 project_id = self .project_id or hook .project_id
332- if project_id and self .model_id :
337+ dataset_id : str | None = self .model .dataset_id if self .model else None
338+ if project_id and self .model_id and dataset_id :
333339 TranslationLegacyModelPredictLink .persist (
334340 context = context ,
335341 task_instance = self ,
336342 model_id = self .model_id ,
343+ dataset_id = dataset_id ,
337344 project_id = project_id ,
338345 )
339346 return PredictResponse .to_dict (result )
@@ -431,12 +438,16 @@ def __init__(
431438 self .input_config = input_config
432439 self .output_config = output_config
433440
434- def execute (self , context : Context ):
435- hook = CloudAutoMLHook (
441+ @cached_property
442+ def hook (self ) -> CloudAutoMLHook :
443+ return CloudAutoMLHook (
436444 gcp_conn_id = self .gcp_conn_id ,
437445 impersonation_chain = self .impersonation_chain ,
438446 )
439- self .model : Model = hook .get_model (
447+
448+ @cached_property
449+ def model (self ) -> Model :
450+ return self .hook .get_model (
440451 model_id = self .model_id ,
441452 location = self .location ,
442453 project_id = self .project_id ,
@@ -445,6 +456,7 @@ def execute(self, context: Context):
445456 metadata = self .metadata ,
446457 )
447458
459+ def execute (self , context : Context ):
448460 if not hasattr (self .model , "translation_model_metadata" ):
449461 _raise_exception_for_deprecated_operator (
450462 self .__class__ .__name__ ,
@@ -456,7 +468,7 @@ def execute(self, context: Context):
456468 ],
457469 )
458470 self .log .info ("Fetch batch prediction." )
459- operation = hook .batch_predict (
471+ operation = self . hook .batch_predict (
460472 model_id = self .model_id ,
461473 input_config = self .input_config ,
462474 output_config = self .output_config ,
@@ -467,16 +479,17 @@ def execute(self, context: Context):
467479 timeout = self .timeout ,
468480 metadata = self .metadata ,
469481 )
470- operation_result = hook .wait_for_operation (timeout = self .timeout , operation = operation )
482+ operation_result = self . hook .wait_for_operation (timeout = self .timeout , operation = operation )
471483 result = BatchPredictResult .to_dict (operation_result )
472484 self .log .info ("Batch prediction is ready." )
473- project_id = self .project_id or hook .project_id
485+ project_id = self .project_id or self . hook .project_id
474486 if project_id :
475487 TranslationLegacyModelPredictLink .persist (
476488 context = context ,
477489 task_instance = self ,
478490 model_id = self .model_id ,
479491 project_id = project_id ,
492+ dataset_id = self .model .dataset_id ,
480493 )
481494 return result
482495
0 commit comments