Commit f1f9201

Create Operators for Google Cloud Vertex AI Context Caching (#43008)

* Fix merge conflicts
* Fix documentation.
* Update return variables.

1 parent c9c4ca5 commit f1f9201

8 files changed: +442 additions, -9 deletions

docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst

Lines changed: 20 additions & 0 deletions

@@ -645,6 +645,26 @@ The operator returns the evaluation summary metrics in :ref:`XCom <concepts:xcom
     :start-after: [START how_to_cloud_vertex_ai_run_evaluation_operator]
     :end-before: [END how_to_cloud_vertex_ai_run_evaluation_operator]
 
+To create cached content, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.CreateCachedContentOperator`.
+The operator returns the cached content resource name in :ref:`XCom <concepts:xcom>` under the ``return_value`` key.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_create_cached_content_operator]
+    :end-before: [END how_to_cloud_vertex_ai_create_cached_content_operator]
+
+To generate a response from cached content, you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateFromCachedContentOperator`.
+The operator returns the cached content response in :ref:`XCom <concepts:xcom>` under the ``return_value`` key.
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_generate_from_cached_content_operator]
+    :end-before: [END how_to_cloud_vertex_ai_generate_from_cached_content_operator]
+
 Reference
 ^^^^^^^^^
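For orientation, below is a minimal sketch of how the two new operators could be chained in a DAG, based only on the operator signatures added in this commit. The DAG scaffolding, project, region, GCS URI, and task ids are illustrative assumptions; the committed system test in example_vertex_ai_generative_model.py remains the authoritative example.

# Illustrative sketch only: project, region, GCS URI, and DAG wiring are
# assumptions, not part of this commit.
from datetime import datetime

from vertexai.generative_models import Part

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    CreateCachedContentOperator,
    GenerateFromCachedContentOperator,
)

with DAG(
    dag_id="example_vertex_ai_context_caching",  # hypothetical DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    create_cached_content = CreateCachedContentOperator(
        task_id="create_cached_content",
        project_id="my-project",  # assumption
        location="us-central1",  # assumption
        model_name="gemini-1.5-pro-002",
        system_instruction="You are an expert researcher.",
        contents=[
            # Large, reusable context worth caching, e.g. a PDF in GCS.
            Part.from_uri(
                "gs://my-bucket/large-document.pdf",  # assumption
                mime_type="application/pdf",
            ),
        ],
        ttl_hours=1,
        display_name="example-cache",
    )

    # The create operator returns the CachedContent resource name via XCom
    # (return_value), so it can be templated into the generate operator.
    generate_from_cached_content = GenerateFromCachedContentOperator(
        task_id="generate_from_cached_content",
        project_id="my-project",
        location="us-central1",
        cached_content_name=(
            "{{ task_instance.xcom_pull(task_ids='create_cached_content') }}"
        ),
        contents=["What is this document about?"],
    )

    create_cached_content >> generate_from_cached_content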

generated/provider_dependencies.json

Lines changed: 1 addition & 1 deletion

@@ -640,7 +640,7 @@
       "google-api-python-client>=2.0.2",
       "google-auth-httplib2>=0.0.1",
       "google-auth>=2.29.0",
-      "google-cloud-aiplatform>=1.63.0",
+      "google-cloud-aiplatform>=1.70.0",
       "google-cloud-automl>=2.12.0",
       "google-cloud-batch>=0.13.0",
       "google-cloud-bigquery-datatransfer>=3.13.0",

providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py

Lines changed: 79 additions & 0 deletions

@@ -20,12 +20,15 @@
 from __future__ import annotations
 
 import time
+from datetime import timedelta
 from typing import TYPE_CHECKING, Sequence
 
 import vertexai
 from vertexai.generative_models import GenerativeModel, Part
 from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
+from vertexai.preview.caching import CachedContent
 from vertexai.preview.evaluation import EvalResult, EvalTask
+from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
 from vertexai.preview.tuning import sft
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -95,6 +98,16 @@ def get_eval_task(
         )
         return eval_task
 
+    def get_cached_context_model(
+        self,
+        cached_content_name: str,
+    ) -> preview_generative_model:
+        """Return a Generative Model with Cached Context."""
+        cached_content = CachedContent(cached_content_name=cached_content_name)
+
+        cached_context_model = preview_generative_model.from_cached_content(cached_content)
+        return cached_context_model
+
     @deprecated(
         planned_removal_date="January 01, 2025",
         use_instead="Part objects included in contents parameter of "
@@ -528,3 +541,69 @@ def run_evaluation(
         )
 
         return eval_result
+
+    def create_cached_content(
+        self,
+        model_name: str,
+        location: str,
+        ttl_hours: float = 1,
+        system_instruction: str | None = None,
+        contents: list | None = None,
+        display_name: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Create CachedContent to reduce the cost of requests that contain repeated content with high input token counts.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param model_name: Required. The name of the publisher model to use for cached content.
+        :param system_instruction: Developer-set system instruction.
+        :param contents: The content to cache.
+        :param ttl_hours: The TTL for this resource in hours. The expiration time is computed as now + TTL.
+            Defaults to one hour.
+        :param display_name: The user-generated meaningful display name of the cached content.
+        """
+        vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
+
+        response = CachedContent.create(
+            model_name=model_name,
+            system_instruction=system_instruction,
+            contents=contents,
+            ttl=timedelta(hours=ttl_hours),
+            display_name=display_name,
+        )
+
+        return response.name
+
+    def generate_from_cached_content(
+        self,
+        location: str,
+        cached_content_name: str,
+        contents: list,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Generate a response from CachedContent.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cached_content_name: Required. The name of the cached content resource.
+        :param contents: Required. The multi-part content of a message that a user or a program
+            gives to the generative model, in order to elicit a specific response.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per-request settings for blocking unsafe content.
+        """
+        vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
+
+        cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)
+
+        response = cached_context_model.generate_content(
+            contents=contents,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
+
+        return response.text
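For completeness, the new hook methods can also be called directly, for example from a PythonOperator callable. A minimal sketch, assuming a configured google_cloud_default Airflow connection; the project, region, contents, and prompt values are illustrative assumptions:

# Illustrative sketch: connection, project, region, and prompt values are assumptions.
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook

hook = GenerativeModelHook(gcp_conn_id="google_cloud_default")

# Cache a reusable prompt prefix; returns the CachedContent resource name.
cached_content_name = hook.create_cached_content(
    project_id="my-project",  # assumption
    location="us-central1",  # assumption
    model_name="gemini-1.5-pro-002",
    system_instruction="You are an expert researcher.",
    contents=[],  # e.g. vertexai Part objects pointing at large documents
    ttl_hours=1,
    display_name="example-cache",
)

# Reuse the cached context for a follow-up request; returns the response text.
answer = hook.generate_from_cached_content(
    project_id="my-project",
    location="us-central1",
    cached_content_name=cached_content_name,
    contents=["What are these documents about?"],
)
print(answer)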

providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py

Lines changed: 152 additions & 4 deletions

@@ -21,8 +21,6 @@
 
 from typing import TYPE_CHECKING, Sequence
 
-from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
-
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -742,8 +740,6 @@ def execute(self, context: Context):
         self.xcom_push(context, key="total_tokens", value=response.total_tokens)
         self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)
 
-        return types_v1beta1.CountTokensResponse.to_dict(response)
-
 
 class RunEvaluationOperator(GoogleCloudBaseOperator):
     """
@@ -842,3 +838,155 @@ def execute(self, context: Context):
         )
 
         return response.summary_metrics
+
+
+class CreateCachedContentOperator(GoogleCloudBaseOperator):
+    """
+    Create CachedContent to reduce the cost of requests that contain repeated content with high input token counts.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param model_name: Required. The name of the publisher model to use for cached content.
+    :param system_instruction: Developer-set system instruction.
+    :param contents: The content to cache.
+    :param ttl_hours: The TTL for this resource in hours. The expiration time is computed as now + TTL.
+        Defaults to one hour.
+    :param display_name: The user-generated meaningful display name of the cached content.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "model_name",
+        "contents",
+        "system_instruction",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        model_name: str,
+        system_instruction: str | None = None,
+        contents: list | None = None,
+        ttl_hours: float = 1,
+        display_name: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.location = location
+        self.model_name = model_name
+        self.system_instruction = system_instruction
+        self.contents = contents
+        self.ttl_hours = ttl_hours
+        self.display_name = display_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        cached_content_name = self.hook.create_cached_content(
+            project_id=self.project_id,
+            location=self.location,
+            model_name=self.model_name,
+            system_instruction=self.system_instruction,
+            contents=self.contents,
+            ttl_hours=self.ttl_hours,
+            display_name=self.display_name,
+        )
+
+        self.log.info("Cached Content Name: %s", cached_content_name)
+
+        return cached_content_name
+
+
+class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
+    """
+    Generate a response from CachedContent.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param cached_content_name: Required. The name of the cached content resource.
+    :param contents: Required. The multi-part content of a message that a user or a program
+        gives to the generative model, in order to elicit a specific response.
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per-request settings for blocking unsafe content.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "cached_content_name",
+        "contents",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        cached_content_name: str,
+        contents: list,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.location = location
+        self.cached_content_name = cached_content_name
+        self.contents = contents
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        cached_content_text = self.hook.generate_from_cached_content(
+            project_id=self.project_id,
+            location=self.location,
+            cached_content_name=self.cached_content_name,
+            contents=self.contents,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
+        )
+
+        self.log.info("Cached Content Response: %s", cached_content_text)
+
+        return cached_content_text

providers/src/airflow/providers/google/provider.yaml

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ dependencies:
   - google-api-python-client>=2.0.2
   - google-auth>=2.29.0
   - google-auth-httplib2>=0.0.1
-  - google-cloud-aiplatform>=1.63.0
+  - google-cloud-aiplatform>=1.70.0
   - google-cloud-automl>=2.12.0
   # Excluded versions contain bug https://github.com/apache/airflow/issues/39541 which is resolved in 3.24.0
   - google-cloud-bigquery>=3.4.0,!=3.21.*,!=3.22.0,!=3.23.*

providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py

Lines changed: 59 additions & 1 deletion

@@ -27,7 +27,9 @@
 
 # For no Pydantic environment, we need to skip the tests
 pytest.importorskip("google.cloud.aiplatform_v1")
-from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, grounding
+from datetime import timedelta
+
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Part, Tool, grounding
 from vertexai.preview.evaluation import MetricPromptTemplateExamples
 
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
@@ -106,6 +108,27 @@
 TEST_EXPERIMENT_RUN_NAME = "eval-experiment-airflow-operator-run"
 TEST_PROMPT_TEMPLATE = "{instruction}. Article: {context}. Summary:"
 
+TEST_CACHED_CONTENT_NAME = "test-example-cache"
+TEST_CACHED_CONTENT_PROMPT = ["What are these papers about?"]
+TEST_CACHED_MODEL = "gemini-1.5-pro-002"
+TEST_CACHED_SYSTEM_INSTRUCTION = """
+You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+Now look at these research papers, and answer the following questions.
+"""
+
+TEST_CACHED_CONTENTS = [
+    Part.from_uri(
+        "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+        mime_type="application/pdf",
+    ),
+    Part.from_uri(
+        "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+        mime_type="application/pdf",
+    ),
+]
+TEST_CACHED_TTL = 1
+TEST_CACHED_DISPLAY_NAME = "test-example-cache"
+
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 GENERATIVE_MODEL_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.generative_model.{}"
 
@@ -299,3 +322,38 @@ def test_run_evaluation(self, mock_eval_task, mock_model) -> None:
             prompt_template=TEST_PROMPT_TEMPLATE,
             experiment_run_name=TEST_EXPERIMENT_RUN_NAME,
         )
+
+    @mock.patch("vertexai.preview.caching.CachedContent.create")
+    def test_create_cached_content(self, mock_cached_content_create) -> None:
+        self.hook.create_cached_content(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            model_name=TEST_CACHED_MODEL,
+            system_instruction=TEST_CACHED_SYSTEM_INSTRUCTION,
+            contents=TEST_CACHED_CONTENTS,
+            ttl_hours=TEST_CACHED_TTL,
+            display_name=TEST_CACHED_DISPLAY_NAME,
+        )
+
+        mock_cached_content_create.assert_called_once_with(
+            model_name=TEST_CACHED_MODEL,
+            system_instruction=TEST_CACHED_SYSTEM_INSTRUCTION,
+            contents=TEST_CACHED_CONTENTS,
+            ttl=timedelta(hours=TEST_CACHED_TTL),
+            display_name=TEST_CACHED_DISPLAY_NAME,
+        )
+
+    @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_cached_context_model"))
+    def test_generate_from_cached_content(self, mock_cached_context_model) -> None:
+        self.hook.generate_from_cached_content(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            cached_content_name=TEST_CACHED_CONTENT_NAME,
+            contents=TEST_CACHED_CONTENT_PROMPT,
+        )
+
+        mock_cached_context_model.return_value.generate_content.assert_called_once_with(
+            contents=TEST_CACHED_CONTENT_PROMPT,
+            generation_config=None,
+            safety_settings=None,
+        )
