Skip to content

Commit 8dcee5b

Browse files
authored
Add deprecation warnings and raise exception for already deprecated ones (#38673)
1 parent da79f6b commit 8dcee5b

File tree

13 files changed

+864
-161
lines changed

13 files changed

+864
-161
lines changed

airflow/providers/google/cloud/hooks/automl.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,3 +640,37 @@ def delete_dataset(
640640
metadata=metadata,
641641
)
642642
return result
643+
644+
@GoogleBaseHook.fallback_to_default_project_id
645+
def get_dataset(
646+
self,
647+
dataset_id: str,
648+
location: str,
649+
project_id: str,
650+
retry: Retry | _MethodDefault = DEFAULT,
651+
timeout: float | None = None,
652+
metadata: Sequence[tuple[str, str]] = (),
653+
) -> Dataset:
654+
"""
655+
Retrieve the dataset for the given dataset_id.
656+
657+
:param dataset_id: ID of dataset to be retrieved.
658+
:param location: The location of the project.
659+
:param project_id: ID of the Google Cloud project where dataset is located if None then
660+
default project_id is used.
661+
:param retry: A retry object used to retry requests. If `None` is specified, requests will not be
662+
retried.
663+
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
664+
`retry` is specified, the timeout applies to each individual attempt.
665+
:param metadata: Additional metadata that is provided to the method.
666+
667+
:return: `google.cloud.automl_v1beta1.types.dataset.Dataset` instance.
668+
"""
669+
client = self.get_conn()
670+
name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
671+
return client.get_dataset(
672+
request={"name": name},
673+
retry=retry,
674+
timeout=timeout,
675+
metadata=metadata,
676+
)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from typing import TYPE_CHECKING, Sequence
21+
22+
from google.api_core.client_options import ClientOptions
23+
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
24+
from google.cloud.aiplatform_v1 import PredictionServiceClient
25+
26+
from airflow.providers.google.common.consts import CLIENT_INFO
27+
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
28+
29+
if TYPE_CHECKING:
30+
from google.api_core.retry import Retry
31+
from google.cloud.aiplatform_v1.types import PredictResponse
32+
33+
34+
class PredictionServiceHook(GoogleBaseHook):
35+
"""Hook for Google Cloud Vertex AI Prediction API."""
36+
37+
def get_prediction_service_client(self, region: str | None = None) -> PredictionServiceClient:
38+
"""
39+
Return PredictionServiceClient object.
40+
41+
:param region: The ID of the Google Cloud region that the service belongs to. Default is None.
42+
43+
:return: `google.cloud.aiplatform_v1.services.prediction_service.client.PredictionServiceClient` instance.
44+
"""
45+
if region and region != "global":
46+
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
47+
else:
48+
client_options = ClientOptions()
49+
50+
return PredictionServiceClient(
51+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
52+
)
53+
54+
@GoogleBaseHook.fallback_to_default_project_id
55+
def predict(
56+
self,
57+
endpoint_id: str,
58+
instances: list[str],
59+
location: str,
60+
project_id: str = PROVIDE_PROJECT_ID,
61+
parameters: dict[str, str] | None = None,
62+
retry: Retry | _MethodDefault = DEFAULT,
63+
timeout: float | None = None,
64+
metadata: Sequence[tuple[str, str]] = (),
65+
) -> PredictResponse:
66+
"""
67+
Perform an online prediction and returns the prediction result in the response.
68+
69+
:param endpoint_id: Name of the endpoint_id requested to serve the prediction.
70+
:param instances: Required. The instances that are the input to the prediction call. A DeployedModel
71+
may have an upper limit on the number of instances it supports per request, and when it is
72+
exceeded the prediction call errors in case of AutoML Models, or, in case of customer created
73+
Models, the behaviour is as documented by that Model.
74+
:param parameters: Additional domain-specific parameters, any string must be up to 25000 characters long.
75+
:param project_id: ID of the Google Cloud project where model is located if None then
76+
default project_id is used.
77+
:param location: The location of the project.
78+
:param retry: A retry object used to retry requests. If `None` is specified, requests will not be
79+
retried.
80+
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
81+
`retry` is specified, the timeout applies to each individual attempt.
82+
:param metadata: Additional metadata that is provided to the method.
83+
"""
84+
client = self.get_prediction_service_client(location)
85+
endpoint = f"projects/{project_id}/locations/{location}/endpoints/{endpoint_id}"
86+
return client.predict(
87+
request={"endpoint": endpoint, "instances": instances, "parameters": parameters},
88+
retry=retry,
89+
timeout=timeout,
90+
metadata=metadata,
91+
)

0 commit comments

Comments
 (0)