16 | 16 | # KIND, either express or implied. See the License for the |
17 | 17 | # specific language governing permissions and limitations |
18 | 18 | # under the License. |
| 19 | +""" |
| 20 | +A template called by DataflowCreatePythonJobOperator to summarize BatchPrediction. |
19 | 21 |
20 | | -"""A template called by DataFlowPythonOperator to summarize BatchPrediction. |
21 | 22 | It accepts a user function to calculate the metric(s) per instance in |
22 | 23 | the prediction results, then aggregates to output as a summary. |
23 | | -Args: |
24 | | - --prediction_path: |
25 | | - The GCS folder that contains BatchPrediction results, containing |
26 | | - prediction.results-NNNNN-of-NNNNN files in the json format. |
27 | | - Output will be also stored in this folder, as 'prediction.summary.json'. |
28 | | - --metric_fn_encoded: |
29 | | - An encoded function that calculates and returns a tuple of metric(s) |
30 | | - for a given instance (as a dictionary). It should be encoded |
31 | | - via base64.b64encode(dill.dumps(fn, recurse=True)). |
32 | | - --metric_keys: |
33 | | - A comma-separated key(s) of the aggregated metric(s) in the summary |
34 | | - output. The order and the size of the keys must match to the output |
35 | | - of metric_fn. |
36 | | - The summary will have an additional key, 'count', to represent the |
37 | | - total number of instances, so the keys shouldn't include 'count'. |
38 | | -# Usage example: |
39 | | -from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator |
40 | | -def get_metric_fn(): |
41 | | - import math # all imports must be outside of the function to be passed. |
42 | | - def metric_fn(inst): |
43 | | - label = float(inst["input_label"]) |
44 | | - classes = float(inst["classes"]) |
45 | | - prediction = float(inst["scores"][1]) |
46 | | - log_loss = math.log(1 + math.exp( |
47 | | - -(label * 2 - 1) * math.log(prediction / (1 - prediction)))) |
48 | | - squared_err = (classes-label)**2 |
49 | | - return (log_loss, squared_err) |
50 | | - return metric_fn |
51 | | -metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)) |
52 | | -DataflowCreatePythonJobOperator( |
53 | | - task_id="summary-prediction", |
54 | | - py_options=["-m"], |
55 | | - py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
56 | | - options={ |
57 | | - "prediction_path": prediction_path, |
58 | | - "metric_fn_encoded": metric_fn_encoded, |
59 | | - "metric_keys": "log_loss,mse" |
60 | | - }, |
61 | | - dataflow_default_options={ |
62 | | - "project": "xxx", "region": "us-east1", |
63 | | - "staging_location": "gs://yy", "temp_location": "gs://zz", |
64 | | - }) |
65 | | - >> dag |
66 | | -# When the input file is like the following: |
67 | | -{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]} |
68 | | -{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]} |
69 | | -{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]} |
70 | | -{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]} |
71 | | -# The output file will be: |
72 | | -{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25} |
73 | | -# To test outside of the dag: |
74 | | -subprocess.check_call(["python", |
75 | | - "-m", |
76 | | - "airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
77 | | - "--prediction_path=gs://...", |
78 | | - "--metric_fn_encoded=" + metric_fn_encoded, |
79 | | - "--metric_keys=log_loss,mse", |
80 | | - "--runner=DataflowRunner", |
81 | | - "--staging_location=gs://...", |
82 | | - "--temp_location=gs://...", |
83 | | - ]) |
| 24 | +
| 25 | +It accepts the following arguments: |
| 26 | +
| 27 | +- ``--prediction_path``: |
| 28 | + The GCS folder that contains the BatchPrediction results, i.e. |
| 29 | + prediction.results-NNNNN-of-NNNNN files in JSON format. |
| 30 | + Output will also be stored in this folder, as 'prediction.summary.json'. |
| 31 | +- ``--metric_fn_encoded``: |
| 32 | + An encoded function that calculates and returns a tuple of metric(s) |
| 33 | + for a given instance (as a dictionary). It should be encoded |
| 34 | + via base64.b64encode(dill.dumps(fn, recurse=True)). |
| 35 | +- ``--metric_keys``: |
| 36 | + Comma-separated key(s) of the aggregated metric(s) in the summary |
| 37 | + output. The order and number of the keys must match the output |
| 38 | + of metric_fn. |
| 39 | + The summary will have an additional key, 'count', to represent the |
| 40 | + total number of instances, so the keys shouldn't include 'count'. |
| 41 | +
| 42 | +
| 43 | +Usage example: |
| 44 | +
| 45 | +.. code-block:: python |
| 46 | +
| 47 | + from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator |
| 48 | +
| 49 | +
| 50 | + def get_metric_fn(): |
| 51 | + import math  # all imports must be outside of metric_fn, the function being passed. |
| 52 | + def metric_fn(inst): |
| 53 | + label = float(inst["input_label"]) |
| 54 | + classes = float(inst["classes"]) |
| 55 | + prediction = float(inst["scores"][1]) |
| 56 | + log_loss = math.log(1 + math.exp( |
| 57 | + -(label * 2 - 1) * math.log(prediction / (1 - prediction)))) |
| 58 | + squared_err = (classes-label)**2 |
| 59 | + return (log_loss, squared_err) |
| 60 | + return metric_fn |
| 61 | + metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)).decode() |
| 62 | + DataflowCreatePythonJobOperator( |
| 63 | + task_id="summary-prediction", |
| 64 | + py_options=["-m"], |
| 65 | + py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
| 66 | + options={ |
| 67 | + "prediction_path": prediction_path, |
| 68 | + "metric_fn_encoded": metric_fn_encoded, |
| 69 | + "metric_keys": "log_loss,mse" |
| 70 | + }, |
| 71 | + dataflow_default_options={ |
| 72 | + "project": "xxx", "region": "us-east1", |
| 73 | + "staging_location": "gs://yy", "temp_location": "gs://zz", |
| 74 | + } |
| 75 | + ) >> dag |
| 76 | +
| 77 | +When the input file is like the following:: |
| 78 | +
| 79 | + {"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]} |
| 80 | + {"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]} |
| 81 | + {"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]} |
| 82 | + {"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]} |
| 83 | +
| 84 | +The output file will be:: |
| 85 | +
| 86 | + {"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25} |
| 87 | +
| 88 | +To test outside of the dag: |
| 89 | +
| 90 | +.. code-block:: python |
| 91 | +
| 92 | + subprocess.check_call(["python", |
| 93 | + "-m", |
| 94 | + "airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
| 95 | + "--prediction_path=gs://...", |
| 96 | + "--metric_fn_encoded=" + metric_fn_encoded, |
| 97 | + "--metric_keys=log_loss,mse", |
| 98 | + "--runner=DataflowRunner", |
| 99 | + "--staging_location=gs://...", |
| 100 | + "--temp_location=gs://...", |
| 101 | + ]) |
84 | 102 | """ |
85 | 103 |
86 | 104 | import argparse |
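The commit above only touches the docstring, but the round trip it documents (encode a metric function with `dill` and `base64`, have the template decode it, apply it per instance, and average the results into a summary with an extra `count` key) can be sanity-checked locally. The sketch below is not the template's real implementation, which runs as an Apache Beam pipeline on Dataflow; it is a plain-Python illustration under two assumptions: the `metric_fn` here is a simplified stand-in that reads the ground-truth label from the first field of `inputs` (as the sample rows suggest), and the in-memory `lines` list stands in for the `prediction.results-*` files under `--prediction_path`.

```python
import base64
import json

import dill


def get_metric_fn():
    import math  # imported here so the closure carries it, mirroring the docstring's advice

    def metric_fn(inst):
        # Assumption: the ground-truth label is the first comma-separated field
        # of "inputs", which is how the sample rows above read.
        label = float(inst["inputs"].split(",")[0])
        prediction = float(inst["scores"][1])
        # Standard binary log loss for one instance.
        log_loss = -(label * math.log(prediction) + (1 - label) * math.log(1 - prediction))
        squared_err = (float(inst["classes"]) - label) ** 2
        return (log_loss, squared_err)

    return metric_fn


# Encode the function the way --metric_fn_encoded expects it.
metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)).decode()
metric_keys = "log_loss,mse".split(",")

# This mirrors what the template does on the worker side: decode the function back.
metric_fn = dill.loads(base64.b64decode(metric_fn_encoded))

# Stand-in for the prediction.results-* lines that would be read from GCS.
lines = [
    '{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}',
    '{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}',
    '{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}',
    '{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}',
]

# Apply the metric per instance, then average and add the implicit "count" key.
totals, count = None, 0
for line in lines:
    values = metric_fn(json.loads(line))
    totals = values if totals is None else tuple(t + v for t, v in zip(totals, values))
    count += 1

summary = {key: total / count for key, total in zip(metric_keys, totals)}
summary["count"] = count
print(json.dumps(summary))  # {"log_loss": ..., "mse": 0.25, "count": 4} for the sample rows
```

Passing the metric as a dill/base64 string rather than importing it by name is what lets an arbitrary user-defined function cross the process boundary to the Dataflow workers as a plain command-line flag.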