
Commit 640c0b6

Create CustomJob and Datasets operators for Vertex AI service (#20077)
1 parent 2874d9f commit 640c0b6

File tree

18 files changed: +6736 -1 lines changed

Lines changed: 313 additions & 0 deletions
@@ -0,0 +1,313 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Example Airflow DAG that demonstrates operators for the Google Vertex AI service in the Google
Cloud Platform.

This DAG relies on the following OS environment variables:

* GCP_VERTEX_AI_BUCKET - Google Cloud Storage bucket where the model will be saved
  after the training process finishes.
* CUSTOM_CONTAINER_URI - path to the container with the model.
* PYTHON_PACKAGE_GSC_URI - path to the test model archive.
* LOCAL_TRAINING_SCRIPT_PATH - path to the local training script.
* DATASET_ID - ID of the dataset that will be used in the training process.
"""
import os
from datetime import datetime
from uuid import uuid4

from airflow import models
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    CreateCustomContainerTrainingJobOperator,
    CreateCustomPythonPackageTrainingJobOperator,
    CreateCustomTrainingJobOperator,
    DeleteCustomTrainingJobOperator,
    ListCustomTrainingJobOperator,
)
from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
    CreateDatasetOperator,
    DeleteDatasetOperator,
    ExportDataOperator,
    GetDatasetOperator,
    ImportDataOperator,
    ListDatasetsOperator,
    UpdateDatasetOperator,
)

PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
REGION = os.environ.get("GCP_LOCATION", "us-central1")
BUCKET = os.environ.get("GCP_VERTEX_AI_BUCKET", "vertex-ai-system-tests")

STAGING_BUCKET = f"gs://{BUCKET}"
DISPLAY_NAME = str(uuid4())  # Create random display name
CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest"
CUSTOM_CONTAINER_URI = os.environ.get("CUSTOM_CONTAINER_URI", "path_to_container_with_model")
MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest"
REPLICA_COUNT = 1
MACHINE_TYPE = "n1-standard-4"
ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
ACCELERATOR_COUNT = 0
TRAINING_FRACTION_SPLIT = 0.7
TEST_FRACTION_SPLIT = 0.15
VALIDATION_FRACTION_SPLIT = 0.15

PYTHON_PACKAGE_GCS_URI = os.environ.get("PYTHON_PACKAGE_GSC_URI", "path_to_test_model_in_arch")
PYTHON_MODULE_NAME = "aiplatform_custom_trainer_script.task"

LOCAL_TRAINING_SCRIPT_PATH = os.environ.get("LOCAL_TRAINING_SCRIPT_PATH", "path_to_training_script")

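# NOTE: the environment-based constants above are resolved once, when Airflow parses this
# file; the second argument of os.environ.get() is only a placeholder default. For a real
# run the variables would be exported in the scheduler/worker environment first, for
# example (illustrative values, not real resources):
#   export GCP_PROJECT_ID=my-project
#   export GCP_VERTEX_AI_BUCKET=my-vertex-ai-bucket
#   export CUSTOM_CONTAINER_URI=gcr.io/my-project/housing-trainer:latest
#   export PYTHON_PACKAGE_GSC_URI=gs://my-bucket/trainer-0.1.tar.gz
#   export LOCAL_TRAINING_SCRIPT_PATH=/files/training_script.py
#   export DATASET_ID=1234567890
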
TRAINING_PIPELINE_ID = "test-training-pipeline-id"
CUSTOM_JOB_ID = "test-custom-job-id"

IMAGE_DATASET = {
    "display_name": str(uuid4()),
    "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
    "metadata": "test-image-dataset",
}
TABULAR_DATASET = {
    "display_name": str(uuid4()),
    "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml",
    "metadata": "test-tabular-dataset",
}
TEXT_DATASET = {
    "display_name": str(uuid4()),
    "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml",
    "metadata": "test-text-dataset",
}
VIDEO_DATASET = {
    "display_name": str(uuid4()),
    "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml",
    "metadata": "test-video-dataset",
}
TIME_SERIES_DATASET = {
    "display_name": str(uuid4()),
    "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/time_series_1.0.0.yaml",
    "metadata": "test-time-series-dataset",
}
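# Each dataset spec above follows the Vertex AI Dataset resource shape: a unique
# display_name, a metadata_schema_uri from the public google-cloud-aiplatform schema
# bucket that selects the dataset type (image, tabular, text, video, time series),
# and free-form metadata used here only as a test marker.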
DATASET_ID = os.environ.get("DATASET_ID", "test-dataset-id")
TEST_EXPORT_CONFIG = {"gcs_destination": {"output_uri_prefix": "gs://test-vertex-ai-bucket/exports"}}
TEST_IMPORT_CONFIG = [
    {
        "data_item_labels": {
            "test-labels-name": "test-labels-value",
        },
        "import_schema_uri": (
            "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"
        ),
        "gcs_source": {
            "uris": ["gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"]
        },
    },
]
DATASET_TO_UPDATE = {"display_name": "test-name"}
TEST_UPDATE_MASK = {"paths": ["displayName"]}

with models.DAG(
    "example_gcp_vertex_ai_custom_jobs",
    schedule_interval="@once",
    start_date=datetime(2021, 1, 1),
    catchup=False,
) as custom_jobs_dag:
    # [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
    create_custom_container_training_job = CreateCustomContainerTrainingJobOperator(
        task_id="custom_container_task",
        staging_bucket=STAGING_BUCKET,
        display_name=f"train-housing-container-{DISPLAY_NAME}",
        container_uri=CUSTOM_CONTAINER_URI,
        model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
        # run params
        dataset_id=DATASET_ID,
        command=["python3", "task.py"],
        model_display_name=f"container-housing-model-{DISPLAY_NAME}",
        replica_count=REPLICA_COUNT,
        machine_type=MACHINE_TYPE,
        accelerator_type=ACCELERATOR_TYPE,
        accelerator_count=ACCELERATOR_COUNT,
        training_fraction_split=TRAINING_FRACTION_SPLIT,
        validation_fraction_split=VALIDATION_FRACTION_SPLIT,
        test_fraction_split=TEST_FRACTION_SPLIT,
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator]

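    # The Create*TrainingJobOperator classes above and below essentially wrap the
    # google-cloud-aiplatform training job classes: the first group of arguments
    # configures the job, and the arguments under "# run params" (dataset, machine
    # settings, fraction splits) are forwarded to the job's run() call.
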
    # [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
    create_custom_python_package_training_job = CreateCustomPythonPackageTrainingJobOperator(
        task_id="python_package_task",
        staging_bucket=STAGING_BUCKET,
        display_name=f"train-housing-py-package-{DISPLAY_NAME}",
        python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
        python_module_name=PYTHON_MODULE_NAME,
        container_uri=CONTAINER_URI,
        model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
        # run params
        dataset_id=DATASET_ID,
        model_display_name=f"py-package-housing-model-{DISPLAY_NAME}",
        replica_count=REPLICA_COUNT,
        machine_type=MACHINE_TYPE,
        accelerator_type=ACCELERATOR_TYPE,
        accelerator_count=ACCELERATOR_COUNT,
        training_fraction_split=TRAINING_FRACTION_SPLIT,
        validation_fraction_split=VALIDATION_FRACTION_SPLIT,
        test_fraction_split=TEST_FRACTION_SPLIT,
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]

    # [START how_to_cloud_vertex_ai_create_custom_training_job_operator]
    create_custom_training_job = CreateCustomTrainingJobOperator(
        task_id="custom_task",
        staging_bucket=STAGING_BUCKET,
        display_name=f"train-housing-custom-{DISPLAY_NAME}",
        script_path=LOCAL_TRAINING_SCRIPT_PATH,
        container_uri=CONTAINER_URI,
        requirements=["gcsfs==0.7.1"],
        model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
        # run params
        dataset_id=DATASET_ID,
        replica_count=1,
        model_display_name=f"custom-housing-model-{DISPLAY_NAME}",
        sync=False,
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_create_custom_training_job_operator]

    # [START how_to_cloud_vertex_ai_delete_custom_training_job_operator]
    delete_custom_training_job = DeleteCustomTrainingJobOperator(
        task_id="delete_custom_training_job",
        training_pipeline_id=TRAINING_PIPELINE_ID,
        custom_job_id=CUSTOM_JOB_ID,
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_delete_custom_training_job_operator]

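    # TRAINING_PIPELINE_ID and CUSTOM_JOB_ID are placeholder literals defined at the top
    # of this file; a real cleanup task would point at an existing training pipeline and
    # custom job, for example ones created by the tasks above.
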
    # [START how_to_cloud_vertex_ai_list_custom_training_job_operator]
    list_custom_training_job = ListCustomTrainingJobOperator(
        task_id="list_custom_training_job",
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_list_custom_training_job_operator]
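    # No task dependencies are declared in this DAG, so the custom job examples above
    # run independently of one another.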

with models.DAG(
    "example_gcp_vertex_ai_dataset",
    schedule_interval="@once",
    start_date=datetime(2021, 1, 1),
    catchup=False,
) as dataset_dag:
    # [START how_to_cloud_vertex_ai_create_dataset_operator]
    create_image_dataset_job = CreateDatasetOperator(
        task_id="image_dataset",
        dataset=IMAGE_DATASET,
        region=REGION,
        project_id=PROJECT_ID,
    )
    create_tabular_dataset_job = CreateDatasetOperator(
        task_id="tabular_dataset",
        dataset=TABULAR_DATASET,
        region=REGION,
        project_id=PROJECT_ID,
    )
    create_text_dataset_job = CreateDatasetOperator(
        task_id="text_dataset",
        dataset=TEXT_DATASET,
        region=REGION,
        project_id=PROJECT_ID,
    )
    create_video_dataset_job = CreateDatasetOperator(
        task_id="video_dataset",
        dataset=VIDEO_DATASET,
        region=REGION,
        project_id=PROJECT_ID,
    )
    create_time_series_dataset_job = CreateDatasetOperator(
        task_id="time_series_dataset",
        dataset=TIME_SERIES_DATASET,
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_create_dataset_operator]

    # [START how_to_cloud_vertex_ai_delete_dataset_operator]
    delete_dataset_job = DeleteDatasetOperator(
        task_id="delete_dataset",
        dataset_id=create_text_dataset_job.output['dataset_id'],
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_delete_dataset_operator]

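    # create_text_dataset_job.output['dataset_id'] is an XComArg: at runtime it resolves
    # to the dataset_id XCom pushed by the create task, and passing it here also makes
    # create_text_dataset_job an upstream dependency of delete_dataset. The dataset
    # operators below use the same pattern.
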
    # [START how_to_cloud_vertex_ai_get_dataset_operator]
    get_dataset = GetDatasetOperator(
        task_id="get_dataset",
        project_id=PROJECT_ID,
        region=REGION,
        dataset_id=create_tabular_dataset_job.output['dataset_id'],
    )
    # [END how_to_cloud_vertex_ai_get_dataset_operator]

    # [START how_to_cloud_vertex_ai_export_data_operator]
    export_data_job = ExportDataOperator(
        task_id="export_data",
        dataset_id=create_image_dataset_job.output['dataset_id'],
        region=REGION,
        project_id=PROJECT_ID,
        export_config=TEST_EXPORT_CONFIG,
    )
    # [END how_to_cloud_vertex_ai_export_data_operator]

    # [START how_to_cloud_vertex_ai_import_data_operator]
    import_data_job = ImportDataOperator(
        task_id="import_data",
        dataset_id=create_image_dataset_job.output['dataset_id'],
        region=REGION,
        project_id=PROJECT_ID,
        import_configs=TEST_IMPORT_CONFIG,
    )
    # [END how_to_cloud_vertex_ai_import_data_operator]

    # [START how_to_cloud_vertex_ai_list_dataset_operator]
    list_dataset_job = ListDatasetsOperator(
        task_id="list_dataset",
        region=REGION,
        project_id=PROJECT_ID,
    )
    # [END how_to_cloud_vertex_ai_list_dataset_operator]

    # [START how_to_cloud_vertex_ai_update_dataset_operator]
    update_dataset_job = UpdateDatasetOperator(
        task_id="update_dataset",
        project_id=PROJECT_ID,
        region=REGION,
        dataset_id=create_video_dataset_job.output['dataset_id'],
        dataset=DATASET_TO_UPDATE,
        update_mask=TEST_UPDATE_MASK,
    )
    # [END how_to_cloud_vertex_ai_update_dataset_operator]

    create_time_series_dataset_job
    create_text_dataset_job >> delete_dataset_job
    create_tabular_dataset_job >> get_dataset
    create_image_dataset_job >> import_data_job >> export_data_job
    create_video_dataset_job >> update_dataset_job
    list_dataset_job
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

0 commit comments
