Skip to content

Commit c526fb6

Browse files
author
Bas van Beek
committed
ENH,WIP: Add a mypy plugin for casting np.number instances to appropiate subclasses
1 parent dbf2018 commit c526fb6

2 files changed

Lines changed: 60 additions & 1 deletion

File tree

number_plugin.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import typing as t
2+
3+
import mypy.types
4+
from mypy.types import Type, Instance
5+
from mypy.plugin import Plugin, MethodContext, FunctionContext
6+
7+
_NAMES: t.Mapping[str, str] = {
8+
"numpy.signedinteger": "numpy.int",
9+
"numpy.unsignedinteger": "numpy.uint",
10+
"numpy.floating": "numpy.float",
11+
"numpy.complexfloating": "numpy.complex",
12+
}
13+
14+
_PRECISION: t.Mapping[str, int] = {
15+
"numpy._64Bit": 64,
16+
"numpy._32Bit": 32,
17+
"numpy._16Bit": 16,
18+
"numpy._8Bit": 8,
19+
}
20+
21+
22+
def _hook(ctx: t.Union[FunctionContext, MethodContext]) -> Type:
23+
api = ctx.api
24+
ret_type = ctx.default_return_type
25+
if not isinstance(ret_type, Instance):
26+
return ret_type
27+
28+
# There are 3 dict lookups where a `KeyError` could potentially be raised
29+
# If this hapens, return the original `ret_type` in unaltered form
30+
try:
31+
name = _NAMES[ret_type.type.fullname] # dict lookup #1
32+
33+
# Parse the precision
34+
_precision = ret_type.args[0]
35+
if not isinstance(_precision, Instance):
36+
return ret_type
37+
precision = _PRECISION[_precision.type.fullname] # dict lookup #2
38+
39+
if name == "numpy.complex":
40+
precision *= 2
41+
42+
return api.named_type(f'{name}{precision}') # Dict lookup #3
43+
except KeyError:
44+
return ret_type
45+
46+
47+
class NumberPlugin(Plugin):
48+
def get_method_hook(self, fullname: str
49+
) -> t.Optional[t.Callable[[MethodContext], Type]]:
50+
return _hook
51+
52+
def get_function_hook(self, fullname: str
53+
) -> t.Optional[t.Callable[[FunctionContext], Type]]:
54+
return _hook
55+
56+
57+
def plugin(version: str) -> t.Type[NumberPlugin]:
58+
return NumberPlugin

numpy/typing/tests/data/mypy.ini

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[mypy]
2-
mypy_path = ../../..
2+
mypy_path = ../../../..
3+
plugins = number_plugin
34

45
[mypy-numpy]
56
ignore_errors = True

0 commit comments

Comments
 (0)