Skip to content

Commit e1978a4

Browse files
Add typing for some internal python files. (#31514)
* Add typing for some internal python files.
1 parent 90d8754 commit e1978a4

10 files changed

Lines changed: 359 additions & 186 deletions

File tree

setup.cfg

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,23 @@ license_files = LICENSE
2121

2222
# NOTE(lidiz) Adding examples one by one due to pytype aggressive errer:
2323
# ninja: error: build.ninja:178: multiple rules generate helloworld_pb2.pyi [-w dupbuild=err]
24+
# TODO(xuanwn): include all files in src/python/grpcio/grpc
2425
[pytype]
2526
inputs =
2627
src/python/grpcio/grpc/experimental
28+
src/python/grpcio/grpc
2729
src/python/grpcio_tests/tests_aio
2830
examples/python/auth
2931
examples/python/helloworld
3032
exclude =
3133
**/*_pb2.py
34+
src/python/grpcio/grpc/framework
35+
src/python/grpcio/grpc/aio
36+
src/python/grpcio/grpc/beta
37+
src/python/grpcio/grpc/__init__.py
38+
src/python/grpcio/grpc/_channel.py
39+
src/python/grpcio/grpc/_server.py
40+
src/python/grpcio/grpc/_simple_stubs.py
3241

3342
# NOTE(lidiz)
3443
# import-error: C extension triggers import-error.

src/python/grpcio/grpc/BUILD.bazel

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ py_library(
8989
srcs = ["_runtime_protos.py"],
9090
)
9191

92+
py_library(
93+
name = "_typing",
94+
srcs = ["_typing.py"],
95+
)
96+
9297
py_library(
9398
name = "grpcio",
9499
srcs = ["__init__.py"],
@@ -99,6 +104,7 @@ py_library(
99104
deps = [
100105
":_runtime_protos",
101106
":_simple_stubs",
107+
":_typing",
102108
":aio",
103109
":auth",
104110
":channel",

src/python/grpcio/grpc/_auth.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,39 @@
1414
"""GRPCAuthMetadataPlugins for standard authentication."""
1515

1616
import inspect
17+
from typing import Any, Optional
1718

1819
import grpc
1920

2021

21-
def _sign_request(callback, token, error):
22+
def _sign_request(callback: grpc.AuthMetadataPluginCallback,
23+
token: Optional[str], error: Optional[Exception]):
2224
metadata = (('authorization', 'Bearer {}'.format(token)),)
2325
callback(metadata, error)
2426

2527

2628
class GoogleCallCredentials(grpc.AuthMetadataPlugin):
2729
"""Metadata wrapper for GoogleCredentials from the oauth2client library."""
30+
_is_jwt: bool
31+
_credentials: Any
2832

29-
def __init__(self, credentials):
33+
# TODO(xuanwn): Give credentials an actual type.
34+
def __init__(self, credentials: Any):
3035
self._credentials = credentials
3136
# Hack to determine if these are JWT creds and we need to pass
3237
# additional_claims when getting a token
3338
self._is_jwt = 'additional_claims' in inspect.getfullargspec(
3439
credentials.get_access_token).args
3540

36-
def __call__(self, context, callback):
41+
def __call__(self, context: grpc.AuthMetadataContext,
42+
callback: grpc.AuthMetadataPluginCallback):
3743
try:
3844
if self._is_jwt:
3945
access_token = self._credentials.get_access_token(
4046
additional_claims={
41-
'aud': context.service_url
47+
'aud':
48+
context.
49+
service_url # pytype: disable=attribute-error
4250
}).access_token
4351
else:
4452
access_token = self._credentials.get_access_token().access_token
@@ -50,9 +58,11 @@ def __call__(self, context, callback):
5058

5159
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
5260
"""Metadata wrapper for raw access token credentials."""
61+
_access_token: str
5362

54-
def __init__(self, access_token):
63+
def __init__(self, access_token: str):
5564
self._access_token = access_token
5665

57-
def __call__(self, context, callback):
66+
def __call__(self, context: grpc.AuthMetadataContext,
67+
callback: grpc.AuthMetadataPluginCallback):
5868
_sign_request(callback, self._access_token, None)

src/python/grpcio/grpc/_common.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515

1616
import logging
1717
import time
18+
from typing import Any, AnyStr, Callable, Optional, Union
1819

1920
import grpc
2021
from grpc._cython import cygrpc
22+
from grpc._typing import DeserializingFunction
23+
from grpc._typing import SerializingFunction
2124

2225
_LOGGER = logging.getLogger(__name__)
2326

@@ -64,20 +67,22 @@
6467
'GRPC_VERBOSITY=debug environment variable to see detailed error message.'
6568

6669

67-
def encode(s):
70+
def encode(s: AnyStr) -> bytes:
6871
if isinstance(s, bytes):
6972
return s
7073
else:
7174
return s.encode('utf8')
7275

7376

74-
def decode(b):
77+
def decode(b: AnyStr) -> str:
7578
if isinstance(b, bytes):
7679
return b.decode('utf-8', 'replace')
7780
return b
7881

7982

80-
def _transform(message, transformer, exception_message):
83+
def _transform(message: Any, transformer: Union[SerializingFunction,
84+
DeserializingFunction, None],
85+
exception_message: str) -> Any:
8186
if transformer is None:
8287
return message
8388
else:
@@ -88,26 +93,31 @@ def _transform(message, transformer, exception_message):
8893
return None
8994

9095

91-
def serialize(message, serializer):
96+
def serialize(message: Any, serializer: Optional[SerializingFunction]) -> bytes:
9297
return _transform(message, serializer, 'Exception serializing message!')
9398

9499

95-
def deserialize(serialized_message, deserializer):
100+
def deserialize(serialized_message: bytes,
101+
deserializer: Optional[DeserializingFunction]) -> Any:
96102
return _transform(serialized_message, deserializer,
97103
'Exception deserializing message!')
98104

99105

100-
def fully_qualified_method(group, method):
106+
def fully_qualified_method(group: str, method: str) -> str:
101107
return '/{}/{}'.format(group, method)
102108

103109

104-
def _wait_once(wait_fn, timeout, spin_cb):
110+
def _wait_once(wait_fn: Callable[..., None], timeout: float,
111+
spin_cb: Optional[Callable[[], None]]):
105112
wait_fn(timeout=timeout)
106113
if spin_cb is not None:
107114
spin_cb()
108115

109116

110-
def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None):
117+
def wait(wait_fn: Callable[..., None],
118+
wait_complete_fn: Callable[[], bool],
119+
timeout: Optional[float] = None,
120+
spin_cb: Optional[Callable[[], None]] = None) -> bool:
111121
"""Blocks waiting for an event without blocking the thread indefinitely.
112122
113123
See https://github.com/grpc/grpc/issues/19464 for full context. CPython's
@@ -148,7 +158,7 @@ def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None):
148158
return False
149159

150160

151-
def validate_port_binding_result(address, port):
161+
def validate_port_binding_result(address: str, port: int) -> int:
152162
"""Validates if the port binding succeed.
153163
154164
If the port returned by Core is 0, the binding is failed. However, in that

src/python/grpcio/grpc/_compression.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import Optional
18+
19+
import grpc
1520
from grpc._cython import cygrpc
21+
from grpc._typing import MetadataType
1622

1723
NoCompression = cygrpc.CompressionAlgorithm.none
1824
Deflate = cygrpc.CompressionAlgorithm.deflate
@@ -25,21 +31,23 @@
2531
}
2632

2733

28-
def _compression_algorithm_to_metadata_value(compression):
34+
def _compression_algorithm_to_metadata_value(
35+
compression: grpc.Compression) -> str:
2936
return _METADATA_STRING_MAPPING[compression]
3037

3138

32-
def compression_algorithm_to_metadata(compression):
39+
def compression_algorithm_to_metadata(compression: grpc.Compression):
3340
return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
3441
_compression_algorithm_to_metadata_value(compression))
3542

3643

37-
def create_channel_option(compression):
44+
def create_channel_option(compression: Optional[grpc.Compression]):
3845
return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM,
3946
int(compression)),) if compression else ()
4047

4148

42-
def augment_metadata(metadata, compression):
49+
def augment_metadata(metadata: Optional[MetadataType],
50+
compression: Optional[grpc.Compression]):
4351
if not metadata and not compression:
4452
return None
4553
base_metadata = tuple(metadata) if metadata else ()

0 commit comments

Comments
 (0)