Skip to content

Commit 2bfc53b

Browse files
authored
Fix doc errors in google provider files. (#11713)
These files aren't _currently_ rendered/parsed by autoapi, but I was exploring making them parseable and ran into some Sphinx formatting errors. The `Args:` change is because pydocstyle thinks that is a special word, but we don't want it to be.
1 parent 53e6062 commit 2bfc53b

File tree

3 files changed

+107
-84
lines changed

3 files changed

+107
-84
lines changed

airflow/providers/google/cloud/utils/mlengine_operator_utils.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def create_evaluate_ops( # pylint: disable=too-many-arguments
7878
7979
Callers will provide two python callables, metric_fn and validate_fn, in
8080
order to customize the evaluation behavior as they wish.
81+
8182
- metric_fn receives a dictionary per instance derived from json in the
8283
batch prediction result. The keys might vary depending on the model.
8384
It should return a tuple of metrics.
@@ -93,24 +94,26 @@ def create_evaluate_ops( # pylint: disable=too-many-arguments
9394
9495
Typical examples are like this:
9596
96-
def get_metric_fn_and_keys():
97-
import math # imports should be outside of the metric_fn below.
98-
def error_and_squared_error(inst):
99-
label = float(inst['input_label'])
100-
classes = float(inst['classes']) # 0 or 1
101-
err = abs(classes-label)
102-
squared_err = math.pow(classes-label, 2)
103-
return (err, squared_err) # returns a tuple.
104-
return error_and_squared_error, ['err', 'mse'] # key order must match.
105-
106-
def validate_err_and_count(summary):
107-
if summary['err'] > 0.2:
108-
raise ValueError('Too high err>0.2; summary=%s' % summary)
109-
if summary['mse'] > 0.05:
110-
raise ValueError('Too high mse>0.05; summary=%s' % summary)
111-
if summary['count'] < 1000:
112-
raise ValueError('Too few instances<1000; summary=%s' % summary)
113-
return summary
97+
.. code-block:: python
98+
99+
def get_metric_fn_and_keys():
100+
import math # imports should be outside of the metric_fn below.
101+
def error_and_squared_error(inst):
102+
label = float(inst['input_label'])
103+
classes = float(inst['classes']) # 0 or 1
104+
err = abs(classes-label)
105+
squared_err = math.pow(classes-label, 2)
106+
return (err, squared_err) # returns a tuple.
107+
return error_and_squared_error, ['err', 'mse'] # key order must match.
108+
109+
def validate_err_and_count(summary):
110+
if summary['err'] > 0.2:
111+
raise ValueError('Too high err>0.2; summary=%s' % summary)
112+
if summary['mse'] > 0.05:
113+
raise ValueError('Too high mse>0.05; summary=%s' % summary)
114+
if summary['count'] < 1000:
115+
raise ValueError('Too few instances<1000; summary=%s' % summary)
116+
return summary
114117
115118
For the details on the other BatchPrediction-related arguments (project_id,
116119
job_id, region, data_format, input_paths, prediction_path, model_uri),
@@ -131,8 +134,10 @@ def validate_err_and_count(summary):
131134
:type prediction_path: str
132135
133136
:param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
137+
134138
- metric_fn is a function that accepts a dictionary (for an instance),
135139
and returns a tuple of metric(s) that it calculates.
140+
136141
- metric_keys is a list of strings to denote the key of each metric.
137142
:type metric_fn_and_keys: tuple of a function and a list[str]
138143

airflow/providers/google/cloud/utils/mlengine_prediction_summary.py

Lines changed: 80 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,71 +16,89 @@
1616
# KIND, either express or implied. See the License for the
1717
# specific language governing permissions and limitations
1818
# under the License.
19+
"""
20+
A template called by DataFlowPythonOperator to summarize BatchPrediction.
1921
20-
"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
2122
It accepts a user function to calculate the metric(s) per instance in
2223
the prediction results, then aggregates to output as a summary.
23-
Args:
24-
--prediction_path:
25-
The GCS folder that contains BatchPrediction results, containing
26-
prediction.results-NNNNN-of-NNNNN files in the json format.
27-
Output will be also stored in this folder, as 'prediction.summary.json'.
28-
--metric_fn_encoded:
29-
An encoded function that calculates and returns a tuple of metric(s)
30-
for a given instance (as a dictionary). It should be encoded
31-
via base64.b64encode(dill.dumps(fn, recurse=True)).
32-
--metric_keys:
33-
A comma-separated key(s) of the aggregated metric(s) in the summary
34-
output. The order and the size of the keys must match to the output
35-
of metric_fn.
36-
The summary will have an additional key, 'count', to represent the
37-
total number of instances, so the keys shouldn't include 'count'.
38-
# Usage example:
39-
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
40-
def get_metric_fn():
41-
import math # all imports must be outside of the function to be passed.
42-
def metric_fn(inst):
43-
label = float(inst["input_label"])
44-
classes = float(inst["classes"])
45-
prediction = float(inst["scores"][1])
46-
log_loss = math.log(1 + math.exp(
47-
-(label * 2 - 1) * math.log(prediction / (1 - prediction))))
48-
squared_err = (classes-label)**2
49-
return (log_loss, squared_err)
50-
return metric_fn
51-
metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
52-
DataflowCreatePythonJobOperator(
53-
task_id="summary-prediction",
54-
py_options=["-m"],
55-
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
56-
options={
57-
"prediction_path": prediction_path,
58-
"metric_fn_encoded": metric_fn_encoded,
59-
"metric_keys": "log_loss,mse"
60-
},
61-
dataflow_default_options={
62-
"project": "xxx", "region": "us-east1",
63-
"staging_location": "gs://yy", "temp_location": "gs://zz",
64-
})
65-
>> dag
66-
# When the input file is like the following:
67-
{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
68-
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
69-
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
70-
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
71-
# The output file will be:
72-
{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
73-
# To test outside of the dag:
74-
subprocess.check_call(["python",
75-
"-m",
76-
"airflow.providers.google.cloud.utils.mlengine_prediction_summary",
77-
"--prediction_path=gs://...",
78-
"--metric_fn_encoded=" + metric_fn_encoded,
79-
"--metric_keys=log_loss,mse",
80-
"--runner=DataflowRunner",
81-
"--staging_location=gs://...",
82-
"--temp_location=gs://...",
83-
])
24+
25+
It accepts the following arguments:
26+
27+
- ``--prediction_path``:
28+
The GCS folder that contains BatchPrediction results, containing
29+
prediction.results-NNNNN-of-NNNNN files in the json format.
30+
Output will be also stored in this folder, as 'prediction.summary.json'.
31+
- ``--metric_fn_encoded``:
32+
An encoded function that calculates and returns a tuple of metric(s)
33+
for a given instance (as a dictionary). It should be encoded
34+
via base64.b64encode(dill.dumps(fn, recurse=True)).
35+
- ``--metric_keys``:
36+
A comma-separated key(s) of the aggregated metric(s) in the summary
37+
output. The order and the size of the keys must match to the output
38+
of metric_fn.
39+
The summary will have an additional key, 'count', to represent the
40+
total number of instances, so the keys shouldn't include 'count'.
41+
42+
43+
Usage example:
44+
45+
.. code-block:: python
46+
47+
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
48+
49+
50+
def get_metric_fn():
51+
import math # all imports must be outside of the function to be passed.
52+
def metric_fn(inst):
53+
label = float(inst["input_label"])
54+
classes = float(inst["classes"])
55+
prediction = float(inst["scores"][1])
56+
log_loss = math.log(1 + math.exp(
57+
-(label * 2 - 1) * math.log(prediction / (1 - prediction))))
58+
squared_err = (classes-label)**2
59+
return (log_loss, squared_err)
60+
return metric_fn
61+
metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
62+
DataflowCreatePythonJobOperator(
63+
task_id="summary-prediction",
64+
py_options=["-m"],
65+
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
66+
options={
67+
"prediction_path": prediction_path,
68+
"metric_fn_encoded": metric_fn_encoded,
69+
"metric_keys": "log_loss,mse"
70+
},
71+
dataflow_default_options={
72+
"project": "xxx", "region": "us-east1",
73+
"staging_location": "gs://yy", "temp_location": "gs://zz",
74+
}
75+
) >> dag
76+
77+
When the input file is like the following::
78+
79+
{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
80+
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
81+
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
82+
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
83+
84+
The output file will be::
85+
86+
{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
87+
88+
To test outside of the dag:
89+
90+
.. code-block:: python
91+
92+
subprocess.check_call(["python",
93+
"-m",
94+
"airflow.providers.google.cloud.utils.mlengine_prediction_summary",
95+
"--prediction_path=gs://...",
96+
"--metric_fn_encoded=" + metric_fn_encoded,
97+
"--metric_keys=log_loss,mse",
98+
"--runner=DataflowRunner",
99+
"--staging_location=gs://...",
100+
"--temp_location=gs://...",
101+
])
84102
"""
85103

86104
import argparse

airflow/providers/google/common/utils/id_token_credentials.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def _load_credentials_from_file(
6262
6363
:param filename: The full path to the credentials file.
6464
:type filename: str
65-
:return Loaded credentials
66-
:rtype google.auth.credentials.Credentials
65+
:return: Loaded credentials
66+
:rtype: google.auth.credentials.Credentials
6767
:raise google.auth.exceptions.DefaultCredentialsError: if the file is in the wrong format or is missing.
6868
"""
6969
if not os.path.exists(filename):
@@ -184,8 +184,8 @@ def get_default_id_token_credentials(
184184
is running on Compute Engine. If not specified, then it will use the standard library http client
185185
to make requests.
186186
:type request: google.auth.transport.Request
187-
:return the current environment's credentials.
188-
:rtype google.auth.credentials.Credentials
187+
:return: the current environment's credentials.
188+
:rtype: google.auth.credentials.Credentials
189189
:raises ~google.auth.exceptions.DefaultCredentialsError:
190190
If no credentials were found, or if the credentials found were invalid.
191191
"""

0 commit comments

Comments
 (0)