21 | 21 |
22 | 22 | from typing import TYPE_CHECKING, Sequence |
23 | 23 |
24 | | -from google.cloud.aiplatform_v1beta1 import types as types_v1beta1 |
25 | | - |
26 | 24 | from airflow.exceptions import AirflowProviderDeprecationWarning |
27 | 25 | from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook |
28 | 26 | from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator |
@@ -742,8 +740,6 @@ def execute(self, context: Context): |
742 | 740 | self.xcom_push(context, key="total_tokens", value=response.total_tokens) |
743 | 741 | self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters) |
744 | 742 |
745 | | - return types_v1beta1.CountTokensResponse.to_dict(response) |
746 | | - |
747 | 743 |
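With the CountTokensResponse return value dropped here, the token counts stay available to downstream tasks only through the two XCom keys pushed above. A minimal sketch of reading them, assuming a hypothetical task_id of "count_tokens" for the task that runs this operator:

from airflow.decorators import task

@task
def log_token_usage(ti=None):
    # Pull the values pushed by execute() via xcom_push above; "count_tokens" is
    # a placeholder task_id, not a name taken from this change.
    total_tokens = ti.xcom_pull(task_ids="count_tokens", key="total_tokens")
    total_billable_characters = ti.xcom_pull(
        task_ids="count_tokens", key="total_billable_characters"
    )
    print(f"total_tokens={total_tokens}, total_billable_characters={total_billable_characters}")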
748 | 744 | class RunEvaluationOperator(GoogleCloudBaseOperator): |
749 | 745 | """ |
@@ -842,3 +838,155 @@ def execute(self, context: Context): |
842 | 838 | ) |
843 | 839 |
844 | 840 | return response.summary_metrics |
| 841 | + |
| 842 | + |
| 843 | +class CreateCachedContentOperator(GoogleCloudBaseOperator): |
| 844 | + """ |
| 845 | + Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts. |
| 846 | +
| 847 | + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. |
| 848 | + :param location: Required. The ID of the Google Cloud location that the service belongs to. |
| 849 | + :param model_name: Required. The name of the publisher model to use for cached content. |
| 850 | + :param system_instruction: Developer set system instruction. |
| 851 | + :param contents: The content to cache. |
| 852 | + :param ttl_hours: The TTL for this resource in hours. The expiration time is computed: now + TTL. |
| 853 | + Defaults to one hour. |
| 854 | +    :param display_name: The user-generated meaningful display name of the cached content. |
| 855 | + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. |
| 856 | + :param impersonation_chain: Optional service account to impersonate using short-term |
| 857 | + credentials, or chained list of accounts required to get the access_token |
| 858 | + of the last account in the list, which will be impersonated in the request. |
| 859 | + If set as a string, the account must grant the originating account |
| 860 | + the Service Account Token Creator IAM role. |
| 861 | + If set as a sequence, the identities from the list must grant |
| 862 | + Service Account Token Creator IAM role to the directly preceding identity, with first |
| 863 | + account from the list granting this role to the originating account (templated). |
| 864 | + """ |
| 865 | + |
| 866 | + template_fields = ( |
| 867 | + "location", |
| 868 | + "project_id", |
| 869 | + "impersonation_chain", |
| 870 | + "model_name", |
| 871 | + "contents", |
| 872 | + "system_instruction", |
| 873 | + ) |
| 874 | + |
| 875 | + def __init__( |
| 876 | + self, |
| 877 | + *, |
| 878 | + project_id: str, |
| 879 | + location: str, |
| 880 | + model_name: str, |
| 881 | + system_instruction: str | None = None, |
| 882 | + contents: list | None = None, |
| 883 | + ttl_hours: float = 1, |
| 884 | + display_name: str | None = None, |
| 885 | + gcp_conn_id: str = "google_cloud_default", |
| 886 | + impersonation_chain: str | Sequence[str] | None = None, |
| 887 | + **kwargs, |
| 888 | + ) -> None: |
| 889 | + super().__init__(**kwargs) |
| 890 | + |
| 891 | + self.project_id = project_id |
| 892 | + self.location = location |
| 893 | + self.model_name = model_name |
| 894 | + self.system_instruction = system_instruction |
| 895 | + self.contents = contents |
| 896 | + self.ttl_hours = ttl_hours |
| 897 | + self.display_name = display_name |
| 898 | + self.gcp_conn_id = gcp_conn_id |
| 899 | + self.impersonation_chain = impersonation_chain |
| 900 | + |
| 901 | + def execute(self, context: Context): |
| 902 | + self.hook = GenerativeModelHook( |
| 903 | + gcp_conn_id=self.gcp_conn_id, |
| 904 | + impersonation_chain=self.impersonation_chain, |
| 905 | + ) |
| 906 | + |
| 907 | + cached_content_name = self.hook.create_cached_content( |
| 908 | + project_id=self.project_id, |
| 909 | + location=self.location, |
| 910 | + model_name=self.model_name, |
| 911 | + system_instruction=self.system_instruction, |
| 912 | + contents=self.contents, |
| 913 | + ttl_hours=self.ttl_hours, |
| 914 | + display_name=self.display_name, |
| 915 | + ) |
| 916 | + |
| 917 | + self.log.info("Cached Content Name: %s", cached_content_name) |
| 918 | + |
| 919 | + return cached_content_name |
| 920 | + |
| 921 | + |
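The new operator wraps GenerativeModelHook.create_cached_content and returns the cached content name, so downstream tasks can consume it via XCom. A minimal DAG sketch of wiring it up; the import path, IDs, and model name below are illustrative assumptions rather than values taken from this change:

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    CreateCachedContentOperator,
)

with DAG(dag_id="example_vertex_ai_cached_content", schedule=None) as dag:
    create_cached_content = CreateCachedContentOperator(
        task_id="create_cached_content",
        project_id="my-project",                  # placeholder project
        location="us-central1",                   # placeholder region
        model_name="gemini-1.5-pro-002",          # placeholder publisher model
        system_instruction="You are an expert researcher.",
        contents=["<large block of context that will be reused across requests>"],
        ttl_hours=1,
        display_name="example-cached-content",
    )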
| 922 | +class GenerateFromCachedContentOperator(GoogleCloudBaseOperator): |
| 923 | + """ |
| 924 | + Generate a response from CachedContent. |
| 925 | +
| 926 | + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. |
| 927 | + :param location: Required. The ID of the Google Cloud location that the service belongs to. |
| 928 | + :param cached_content_name: Required. The name of the cached content resource. |
| 929 | + :param contents: Required. The multi-part content of a message that a user or a program |
| 930 | + gives to the generative model, in order to elicit a specific response. |
| 931 | + :param generation_config: Optional. Generation configuration settings. |
| 932 | + :param safety_settings: Optional. Per request settings for blocking unsafe content. |
| 933 | + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. |
| 934 | + :param impersonation_chain: Optional service account to impersonate using short-term |
| 935 | + credentials, or chained list of accounts required to get the access_token |
| 936 | + of the last account in the list, which will be impersonated in the request. |
| 937 | + If set as a string, the account must grant the originating account |
| 938 | + the Service Account Token Creator IAM role. |
| 939 | + If set as a sequence, the identities from the list must grant |
| 940 | + Service Account Token Creator IAM role to the directly preceding identity, with first |
| 941 | + account from the list granting this role to the originating account (templated). |
| 942 | + """ |
| 943 | + |
| 944 | + template_fields = ( |
| 945 | + "location", |
| 946 | + "project_id", |
| 947 | + "impersonation_chain", |
| 948 | + "cached_content_name", |
| 949 | + "contents", |
| 950 | + ) |
| 951 | + |
| 952 | + def __init__( |
| 953 | + self, |
| 954 | + *, |
| 955 | + project_id: str, |
| 956 | + location: str, |
| 957 | + cached_content_name: str, |
| 958 | + contents: list, |
| 959 | + generation_config: dict | None = None, |
| 960 | + safety_settings: dict | None = None, |
| 961 | + gcp_conn_id: str = "google_cloud_default", |
| 962 | + impersonation_chain: str | Sequence[str] | None = None, |
| 963 | + **kwargs, |
| 964 | + ) -> None: |
| 965 | + super().__init__(**kwargs) |
| 966 | + |
| 967 | + self.project_id = project_id |
| 968 | + self.location = location |
| 969 | + self.cached_content_name = cached_content_name |
| 970 | + self.contents = contents |
| 971 | + self.generation_config = generation_config |
| 972 | + self.safety_settings = safety_settings |
| 973 | + self.gcp_conn_id = gcp_conn_id |
| 974 | + self.impersonation_chain = impersonation_chain |
| 975 | + |
| 976 | + def execute(self, context: Context): |
| 977 | + self.hook = GenerativeModelHook( |
| 978 | + gcp_conn_id=self.gcp_conn_id, |
| 979 | + impersonation_chain=self.impersonation_chain, |
| 980 | + ) |
| 981 | + cached_content_text = self.hook.generate_from_cached_content( |
| 982 | + project_id=self.project_id, |
| 983 | + location=self.location, |
| 984 | + cached_content_name=self.cached_content_name, |
| 985 | + contents=self.contents, |
| 986 | + generation_config=self.generation_config, |
| 987 | + safety_settings=self.safety_settings, |
| 988 | + ) |
| 989 | + |
| 990 | + self.log.info("Cached Content Response: %s", cached_content_text) |
| 991 | + |
| 992 | + return cached_content_text |
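Because cached_content_name is listed in template_fields, the name returned by the create task can be passed in through a templated XCom pull, chaining the two new operators naturally. Continuing the hypothetical sketch above (defined inside the same with DAG(...) block); all values remain placeholders:

from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    GenerateFromCachedContentOperator,
)

generate_from_cached_content = GenerateFromCachedContentOperator(
    task_id="generate_from_cached_content",
    project_id="my-project",
    location="us-central1",
    # cached_content_name is templated, so the name returned by the create task's
    # execute() can be pulled straight from XCom at runtime.
    cached_content_name="{{ task_instance.xcom_pull(task_ids='create_cached_content') }}",
    contents=["Summarize the cached context in two sentences."],
    generation_config={"max_output_tokens": 256, "temperature": 0.1},  # placeholder settings
)

create_cached_content >> generate_from_cached_content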