Avoid the builtin numbers module. #144788
Description
🚀 The feature, motivation and pitch
Currently, torch uses the builtin `numbers` module in a few places (only ~40 hits). However, the `numbers` module is problematic for multiple reasons:
- The `numbers` module is incompatible with type annotations (see "int is not a Number?" python/mypy#3186, example: mypy-playground).
  - In particular, annotating function arguments with `numbers.Number`, `numbers.Real`, etc. is a terrible idea.
  - Using runtime checks like `isinstance(x, Number)` forces us to add `type: ignore` comments inside the else branch.
  - In particular, this is a blocker to annotating the `torch.distributions` module ([typing] Add static type hints to `torch.distributions`. #144196), since this is where most of the uses of `numbers` are found, see: [typing] Add type hints to `__init__` methods in `torch.distributions`. #144197 (comment)
- Since it's just an abstract base class, it requires users to call `Number.register(my_number_type)` (or subclass the ABC) to ensure `isinstance` succeeds.
- Internally, `torch.tensor` doesn't seem to care if something is a `numbers.Number`; in fact, the supported types appear to be:
  - symbolic torch scalars: `torch.SymBool`, `torch.SymInt` and `torch.SymFloat`
  - numpy scalars: `numpy.int32`, `numpy.int64`, `numpy.float32`, etc.
  - python built-in scalars: `bool`, `int`, `float`, `complex`
  - things that can be converted to built-in scalars via `__bool__`, `__int__`, `__index__`, `__float__` or `__complex__` (requires specifying `dtype`)

  (see `/torch/csrc/utils/tensor_new.cpp` and `torch/_refs/__init__.py`)

  demo:
```python
import torch
from numbers import Real


class MyReal(Real):
    """Simple wrapper class for float."""

    __slots__ = ("val",)  # note: must be a tuple, not a bare string

    def __init__(self, x) -> None: self.val = float(x)
    def __float__(self): return self.val.__float__()
    def __complex__(self): return self.val.__complex__()
    @property
    def real(self): return MyReal(self.val.real)
    @property
    def imag(self): return MyReal(self.val.imag)
    def conjugate(self): return MyReal(self.val.conjugate())
    def __abs__(self): return MyReal(self.val.__abs__())
    def __neg__(self): return MyReal(self.val.__neg__())
    def __pos__(self): return MyReal(self.val.__pos__())
    def __trunc__(self): return MyReal(self.val.__trunc__())
    def __floor__(self): return MyReal(self.val.__floor__())
    def __ceil__(self): return MyReal(self.val.__ceil__())
    def __round__(self, ndigits=None): return self.val.__round__(ndigits)
    # comparisons return plain bool, not MyReal
    def __eq__(self, other): return self.val == other
    def __lt__(self, other): return self.val < other
    def __le__(self, other): return self.val <= other
    def __add__(self, other): return MyReal(self.val.__add__(other))
    def __radd__(self, other): return MyReal(self.val.__radd__(other))
    def __mul__(self, other): return MyReal(self.val.__mul__(other))
    def __rmul__(self, other): return MyReal(self.val.__rmul__(other))
    def __truediv__(self, other): return MyReal(self.val.__truediv__(other))
    def __rtruediv__(self, other): return MyReal(self.val.__rtruediv__(other))
    def __floordiv__(self, other): return MyReal(self.val.__floordiv__(other))
    def __rfloordiv__(self, other): return MyReal(self.val.__rfloordiv__(other))
    def __mod__(self, other): return MyReal(self.val.__mod__(other))
    def __rmod__(self, other): return MyReal(self.val.__rmod__(other))
    def __pow__(self, exponent): return MyReal(self.val.__pow__(exponent))
    def __rpow__(self, base): return MyReal(self.val.__rpow__(base))


class Pi:
    def __float__(self) -> float:
        return 3.14


torch.tensor(MyReal(3.14), dtype=float)  # ✅
torch.tensor(Pi(), dtype=float)          # ✅
torch.tensor(MyReal(3.14))  # ❌ RuntimeError: Could not infer dtype of MyReal
torch.tensor(Pi())          # ❌ RuntimeError: Could not infer dtype of Pi
```
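To illustrate the registration point above, here is a small stdlib-only sketch (`Celsius` is a hypothetical float-like wrapper, not anything in torch): `isinstance(x, Real)` only starts succeeding once the class is explicitly registered as a virtual subclass, and registration performs no verification of the interface whatsoever:

```python
from numbers import Real


class Celsius:
    """Hypothetical float-like wrapper, used only for illustration."""

    def __init__(self, degrees: float) -> None:
        self.degrees = degrees

    def __float__(self) -> float:
        return float(self.degrees)


t = Celsius(21.5)
print(isinstance(t, Real))  # False: duck typing alone is not recognized

Real.register(Celsius)      # register as a virtual subclass of the ABC
print(isinstance(t, Real))  # True, even though no abstract method was checked
```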
Alternatives
There are 3 main alternatives:
- Use a `Union` type of the supported types (a `tuple` for runtime checks on Python 3.9). `torch` already provides such aliases, for example `torch.types.Number` and `torch._prims_common.Number`.
- Use builtin `Protocol` types like `typing.SupportsFloat`.
  - The main disadvantage here is that `Tensor`, since it implements `__float__`, is a `SupportsFloat` itself, which could require changing some existing if-else tests.
- Provide a custom `Protocol` type.
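As a rough sketch of the third alternative (the name `SupportsAsFloat` is a placeholder, not existing torch API), a custom `Protocol` can be decorated with `typing.runtime_checkable` so existing `isinstance` tests keep working while the type also serves as a static annotation:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsAsFloat(Protocol):
    """Hypothetical protocol: anything convertible to a built-in float."""

    def __float__(self) -> float: ...


class Pi:
    def __float__(self) -> float:
        return 3.14


print(isinstance(Pi(), SupportsAsFloat))    # True: structural match on __float__
print(isinstance("3.14", SupportsAsFloat))  # False: str defines no __float__
```

Note that a `runtime_checkable` `isinstance` check only tests for the *presence* of `__float__`, not its signature or return type, so it is strictly a structural membership test.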
Additional context
One concern could be the speed of `isinstance(x, Number)`; below is a comparison between the approaches.
```python
import torch
from numbers import Real
import numpy as np
from typing import SupportsFloat

T1 = Real
T2 = SupportsFloat
T3 = (bool, int, float, complex, torch.SymBool, torch.SymInt, torch.SymFloat, np.number)

print("Testing float")
x = 3.14
%timeit isinstance(x, T1)  # 237 ns ± 0.374 ns
%timeit isinstance(x, T2)  # 214 ns ± 0.325 ns
%timeit isinstance(x, T3)  # 35 ns ± 0.844 ns

print("Testing np.float32")
y = np.float32(3.14)
%timeit isinstance(y, T1)  # 106 ns ± 2.3 ns
%timeit isinstance(y, T2)  # 223 ns ± 2.33 ns
%timeit isinstance(y, T3)  # 104 ns ± 0.52 ns

print("Testing Tensor")
z = torch.tensor(3.14)
%timeit isinstance(z, T1)  # 117 ns ± 0.962 ns
%timeit isinstance(z, T2)  # 226 ns ± 0.508 ns
%timeit isinstance(z, T3)  # 99.1 ns ± 0.699 ns

print("Testing string (non-match)")
w = "3.14"
%timeit isinstance(w, T1)  # 114 ns ± 1.47 ns
%timeit isinstance(w, T2)  # 2.21 μs ± 79.2 ns
%timeit isinstance(w, T3)  # 95 ns ± 0.887 ns
```

One can see that `isinstance(val, SupportsFloat)` is roughly twice as slow as `isinstance(val, Real)` for a positive match, but can be a lot slower for a negative one. The tuple check can be a lot faster, but its speed depends on the order of the members (if we put `float` last, the first run takes ~90 ns, since the argument is checked sequentially against the provided types).
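For reference, a stdlib-only approximation of this benchmark (plain `timeit` instead of IPython's `%timeit`; the `torch` and `numpy` members are omitted from the tuple so it runs anywhere) might look like:

```python
import timeit
from numbers import Real
from typing import SupportsFloat

T1 = Real
T2 = SupportsFloat
T3 = (bool, int, float, complex)  # trimmed tuple: no torch/numpy members here

for label, val in [("float", 3.14), ("str (non-match)", "3.14")]:
    for name, typ in [("Real", T1), ("SupportsFloat", T2), ("tuple", T3)]:
        # total seconds for 100_000 calls -> nanoseconds per call
        total = timeit.timeit(lambda: isinstance(val, typ), number=100_000)
        print(f"{label:>16} vs {name:<14}: {total * 1e4:7.1f} ns/call")
```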
cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster