Skip to content

Commit b230566

Browse files
authored
Update example DAG for AI Platform operators (#9727)
1 parent 13a827d commit b230566

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

airflow/providers/google/cloud/example_dags/example_mlengine.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
from airflow import models
2626
from airflow.operators.bash import BashOperator
2727
from airflow.providers.google.cloud.operators.mlengine import (
28-
MLEngineCreateVersionOperator, MLEngineDeleteModelOperator, MLEngineDeleteVersionOperator,
29-
MLEngineListVersionsOperator, MLEngineManageModelOperator, MLEngineSetDefaultVersionOperator,
30-
MLEngineStartBatchPredictionJobOperator, MLEngineStartTrainingJobOperator,
28+
MLEngineCreateModelOperator, MLEngineCreateVersionOperator, MLEngineDeleteModelOperator,
29+
MLEngineDeleteVersionOperator, MLEngineGetModelOperator, MLEngineListVersionsOperator,
30+
MLEngineSetDefaultVersionOperator, MLEngineStartBatchPredictionJobOperator,
31+
MLEngineStartTrainingJobOperator,
3132
)
3233
from airflow.providers.google.cloud.utils import mlengine_operator_utils
3334
from airflow.utils.dates import days_ago
@@ -66,30 +67,26 @@
6667
project_id=PROJECT_ID,
6768
region="us-central1",
6869
job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
69-
runtime_version="1.14",
70-
python_version="3.5",
70+
runtime_version="1.15",
71+
python_version="3.7",
7172
job_dir=JOB_DIR,
7273
package_uris=[TRAINER_URI],
7374
training_python_module=TRAINER_PY_MODULE,
7475
training_args=[],
7576
)
7677

77-
create_model = MLEngineManageModelOperator(
78+
create_model = MLEngineCreateModelOperator(
7879
task_id="create-model",
7980
project_id=PROJECT_ID,
80-
operation='create',
8181
model={
8282
"name": MODEL_NAME,
8383
},
8484
)
8585

86-
get_model = MLEngineManageModelOperator(
86+
get_model = MLEngineGetModelOperator(
8787
task_id="get-model",
8888
project_id=PROJECT_ID,
89-
operation="get",
90-
model={
91-
"name": MODEL_NAME,
92-
}
89+
model_name=MODEL_NAME,
9390
)
9491

9592
get_model_result = BashOperator(
@@ -105,10 +102,10 @@
105102
"name": "v1",
106103
"description": "First-version",
107104
"deployment_uri": '{}/keras_export/'.format(JOB_DIR),
108-
"runtime_version": "1.14",
105+
"runtime_version": "1.15",
109106
"machineType": "mls1-c1-m2",
110107
"framework": "TENSORFLOW",
111-
"pythonVersion": "3.5"
108+
"pythonVersion": "3.7"
112109
}
113110
)
114111

@@ -120,10 +117,10 @@
120117
"name": "v2",
121118
"description": "Second version",
122119
"deployment_uri": SAVED_MODEL_PATH,
123-
"runtime_version": "1.14",
120+
"runtime_version": "1.15",
124121
"machineType": "mls1-c1-m2",
125122
"framework": "TENSORFLOW",
126-
"pythonVersion": "3.5"
123+
"pythonVersion": "3.7"
127124
}
128125
)
129126

@@ -148,7 +145,7 @@
148145
prediction = MLEngineStartBatchPredictionJobOperator(
149146
task_id="prediction",
150147
project_id=PROJECT_ID,
151-
job_id="prediciton-{{ ts_nodash }}-{{ params.model_name }}",
148+
job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
152149
region="us-central1",
153150
model_name=MODEL_NAME,
154151
data_format="TEXT",
@@ -203,13 +200,13 @@ def validate_err_and_count(summary: Dict) -> Dict:
203200
return summary
204201

205202
evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
206-
task_prefix="evalueate-ops", # pylint: disable=too-many-arguments
203+
task_prefix="evaluate-ops",
207204
data_format="TEXT",
208205
input_paths=[PREDICTION_INPUT],
209206
prediction_path=PREDICTION_OUTPUT,
210207
metric_fn_and_keys=get_metric_fn_and_keys(),
211208
validate_fn=validate_err_and_count,
212-
batch_prediction_job_id="evalueate-ops-{{ ts_nodash }}-{{ params.model_name }}",
209+
batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
213210
project_id=PROJECT_ID,
214211
region="us-central1",
215212
dataflow_options={

0 commit comments

Comments
 (0)