|
25 | 25 | from airflow import models |
26 | 26 | from airflow.operators.bash import BashOperator |
27 | 27 | from airflow.providers.google.cloud.operators.mlengine import ( |
28 | | - MLEngineCreateVersionOperator, MLEngineDeleteModelOperator, MLEngineDeleteVersionOperator, |
29 | | - MLEngineListVersionsOperator, MLEngineManageModelOperator, MLEngineSetDefaultVersionOperator, |
30 | | - MLEngineStartBatchPredictionJobOperator, MLEngineStartTrainingJobOperator, |
| 28 | + MLEngineCreateModelOperator, MLEngineCreateVersionOperator, MLEngineDeleteModelOperator, |
| 29 | + MLEngineDeleteVersionOperator, MLEngineGetModelOperator, MLEngineListVersionsOperator, |
| 30 | + MLEngineSetDefaultVersionOperator, MLEngineStartBatchPredictionJobOperator, |
| 31 | + MLEngineStartTrainingJobOperator, |
31 | 32 | ) |
32 | 33 | from airflow.providers.google.cloud.utils import mlengine_operator_utils |
33 | 34 | from airflow.utils.dates import days_ago |
|
66 | 67 | project_id=PROJECT_ID, |
67 | 68 | region="us-central1", |
68 | 69 | job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}", |
69 | | - runtime_version="1.14", |
70 | | - python_version="3.5", |
| 70 | + runtime_version="1.15", |
| 71 | + python_version="3.7", |
71 | 72 | job_dir=JOB_DIR, |
72 | 73 | package_uris=[TRAINER_URI], |
73 | 74 | training_python_module=TRAINER_PY_MODULE, |
74 | 75 | training_args=[], |
75 | 76 | ) |
76 | 77 |
|
77 | | - create_model = MLEngineManageModelOperator( |
| 78 | + create_model = MLEngineCreateModelOperator( |
78 | 79 | task_id="create-model", |
79 | 80 | project_id=PROJECT_ID, |
80 | | - operation='create', |
81 | 81 | model={ |
82 | 82 | "name": MODEL_NAME, |
83 | 83 | }, |
84 | 84 | ) |
85 | 85 |
|
86 | | - get_model = MLEngineManageModelOperator( |
| 86 | + get_model = MLEngineGetModelOperator( |
87 | 87 | task_id="get-model", |
88 | 88 | project_id=PROJECT_ID, |
89 | | - operation="get", |
90 | | - model={ |
91 | | - "name": MODEL_NAME, |
92 | | - } |
| 89 | + model_name=MODEL_NAME, |
93 | 90 | ) |
94 | 91 |
|
95 | 92 | get_model_result = BashOperator( |
|
105 | 102 | "name": "v1", |
106 | 103 | "description": "First-version", |
107 | 104 | "deployment_uri": '{}/keras_export/'.format(JOB_DIR), |
108 | | - "runtime_version": "1.14", |
| 105 | + "runtime_version": "1.15", |
109 | 106 | "machineType": "mls1-c1-m2", |
110 | 107 | "framework": "TENSORFLOW", |
111 | | - "pythonVersion": "3.5" |
| 108 | + "pythonVersion": "3.7" |
112 | 109 | } |
113 | 110 | ) |
114 | 111 |
|
|
120 | 117 | "name": "v2", |
121 | 118 | "description": "Second version", |
122 | 119 | "deployment_uri": SAVED_MODEL_PATH, |
123 | | - "runtime_version": "1.14", |
| 120 | + "runtime_version": "1.15", |
124 | 121 | "machineType": "mls1-c1-m2", |
125 | 122 | "framework": "TENSORFLOW", |
126 | | - "pythonVersion": "3.5" |
| 123 | + "pythonVersion": "3.7" |
127 | 124 | } |
128 | 125 | ) |
129 | 126 |
|
|
148 | 145 | prediction = MLEngineStartBatchPredictionJobOperator( |
149 | 146 | task_id="prediction", |
150 | 147 | project_id=PROJECT_ID, |
151 | | - job_id="prediciton-{{ ts_nodash }}-{{ params.model_name }}", |
| 148 | + job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}", |
152 | 149 | region="us-central1", |
153 | 150 | model_name=MODEL_NAME, |
154 | 151 | data_format="TEXT", |
@@ -203,13 +200,13 @@ def validate_err_and_count(summary: Dict) -> Dict: |
203 | 200 | return summary |
204 | 201 |
|
205 | 202 | evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops( |
206 | | - task_prefix="evalueate-ops", # pylint: disable=too-many-arguments |
| 203 | + task_prefix="evaluate-ops", |
207 | 204 | data_format="TEXT", |
208 | 205 | input_paths=[PREDICTION_INPUT], |
209 | 206 | prediction_path=PREDICTION_OUTPUT, |
210 | 207 | metric_fn_and_keys=get_metric_fn_and_keys(), |
211 | 208 | validate_fn=validate_err_and_count, |
212 | | - batch_prediction_job_id="evalueate-ops-{{ ts_nodash }}-{{ params.model_name }}", |
| 209 | + batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}", |
213 | 210 | project_id=PROJECT_ID, |
214 | 211 | region="us-central1", |
215 | 212 | dataflow_options={ |
|
0 commit comments