
Avoid the builtin numbers module. #144788

@randolf-scholz

Description


🚀 The feature, motivation and pitch

Currently, torch uses the builtin numbers module in a few places (only ~40 hits). However, the numbers module is problematic for multiple reasons:

  1. The numbers module is incompatible with type annotations (see int is not a Number? python/mypy#3186, example: mypy-playground).

  2. Since it is just an abstract base class, user-defined number types are not recognized unless they subclass the ABCs or are explicitly registered via Number.register(my_number_type); implementing the right methods alone does not make isinstance succeed.

  3. Internally, torch.tensor doesn't seem to care whether something is a numbers.Number; in fact, the supported types appear to be

    • symbolic torch scalars torch.SymBool, torch.SymInt and torch.SymFloat
    • numpy scalars numpy.int32, numpy.int64, numpy.float32, etc.
    • python built-in scalars bool, int, float, complex
    • things that can be converted to built-in scalars via __bool__, __int__, __index__, __float__ or __complex__ (requires specifying dtype)

    (see /torch/csrc/utils/tensor_new.cpp and torch/_refs/__init__.py)

    Demo (with a few fixes over the obvious transcription bugs: `__slots__` needs a tuple, comparisons should return bool, `__rpow__` should not delegate to `__rmod__`, and `float.__round__` does not accept `ndigits` as a keyword):

    ```python
    import torch
    from numbers import Real


    class MyReal(Real):
        """Simple wrapper class for float."""

        __slots__ = ("val",)  # note the comma: ("val") is just a string

        def __init__(self, x) -> None:
            self.val = float(x)

        def __float__(self): return self.val.__float__()
        def __complex__(self): return self.val.__complex__()

        @property
        def real(self): return MyReal(self.val.real)
        @property
        def imag(self): return MyReal(self.val.imag)
        def conjugate(self): return MyReal(self.val.conjugate())

        def __abs__(self): return MyReal(self.val.__abs__())
        def __neg__(self): return MyReal(self.val.__neg__())
        def __pos__(self): return MyReal(self.val.__pos__())
        def __trunc__(self): return self.val.__trunc__()
        def __floor__(self): return self.val.__floor__()
        def __ceil__(self): return self.val.__ceil__()
        def __round__(self, ndigits=None): return round(self.val, ndigits)

        # comparisons return plain bools, not MyReal
        def __eq__(self, other): return self.val == other
        def __lt__(self, other): return self.val < other
        def __le__(self, other): return self.val <= other

        def __add__(self, other): return MyReal(self.val.__add__(other))
        def __radd__(self, other): return MyReal(self.val.__radd__(other))
        def __mul__(self, other): return MyReal(self.val.__mul__(other))
        def __rmul__(self, other): return MyReal(self.val.__rmul__(other))
        def __truediv__(self, other): return MyReal(self.val.__truediv__(other))
        def __rtruediv__(self, other): return MyReal(self.val.__rtruediv__(other))
        def __floordiv__(self, other): return MyReal(self.val.__floordiv__(other))
        def __rfloordiv__(self, other): return MyReal(self.val.__rfloordiv__(other))
        def __mod__(self, other): return MyReal(self.val.__mod__(other))
        def __rmod__(self, other): return MyReal(self.val.__rmod__(other))

        def __pow__(self, exponent): return MyReal(self.val.__pow__(exponent))
        def __rpow__(self, base): return MyReal(self.val.__rpow__(base))


    class Pi:
        def __float__(self) -> float:
            return 3.14


    torch.tensor(MyReal(3.14), dtype=float)  # ✅
    torch.tensor(Pi(), dtype=float)          # ✅

    torch.tensor(MyReal(3.14))  # ❌ RuntimeError: Could not infer dtype of MyReal
    torch.tensor(Pi())          # ❌ RuntimeError: Could not infer dtype of Pi
    ```
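Point 2 above can be demonstrated directly: a duck-typed scalar class (the name MyScalar is a hypothetical example, not a torch type) is not recognized by isinstance until it is explicitly registered with the ABC:

```python
from numbers import Number


class MyScalar:
    """A duck-typed scalar that quacks like a float."""

    def __float__(self) -> float:
        return 1.0


x = MyScalar()
print(isinstance(x, Number))  # False: implementing __float__ is not enough
Number.register(MyScalar)     # explicit ABC registration
print(isinstance(x, Number))  # True
```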

Alternatives

There are three main alternatives:

  1. Use a Union of the supported types (a tuple at runtime for isinstance checks on Python 3.9). torch already provides such unions, e.g. torch.types.Number and torch._prims_common.Number.
  2. Use built-in Protocol types like typing.SupportsFloat.
     • The main disadvantage here is that Tensor, since it implements __float__, is a SupportsFloat itself, which could require changing some existing if-else tests.
  3. Provide a custom Protocol type.
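Option 3 could look like the following sketch; the protocol name and its single member are assumptions for illustration, not an existing torch API:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsScalar(Protocol):
    """Hypothetical protocol for types convertible to a Python float."""

    def __float__(self) -> float: ...


class Pi:
    def __float__(self) -> float:
        return 3.14


print(isinstance(Pi(), SupportsScalar))    # True: structural check, no registration
print(isinstance("3.14", SupportsScalar))  # False: str defines no __float__
```

Unlike numbers.Real, the protocol needs neither subclassing nor registration, and it also works as a static type annotation.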

Additional context

One concern could be the speed of `isinstance(x, Number)`; below is a comparison between the approaches (IPython `%timeit`):

```python
import torch
from numbers import Real
import numpy as np
from typing import SupportsFloat

T1 = Real
T2 = SupportsFloat
T3 = (bool, int, float, complex, torch.SymBool, torch.SymInt, torch.SymFloat, np.number)

print("Testing float")
x = 3.14
%timeit isinstance(x, T1)  # 237 ns ± 0.374 ns
%timeit isinstance(x, T2)  # 214 ns ± 0.325 ns
%timeit isinstance(x, T3)  #  35 ns ± 0.844 ns

print("Testing np.float32")
y = np.float32(3.14)
%timeit isinstance(y, T1)  # 106 ns ± 2.3 ns
%timeit isinstance(y, T2)  # 223 ns ± 2.33 ns
%timeit isinstance(y, T3)  # 104 ns ± 0.52 ns

print("Testing Tensor")
z = torch.tensor(3.14)
%timeit isinstance(z, T1)  # 117 ns ± 0.962 ns
%timeit isinstance(z, T2)  # 226 ns ± 0.508 ns
%timeit isinstance(z, T3)  # 99.1 ns ± 0.699 ns

print("Testing string (non-match)")
w = "3.14"
%timeit isinstance(w, T1)  # 114 ns ± 1.47 ns
%timeit isinstance(w, T2)  # 2.21 μs ± 79.2 ns
%timeit isinstance(w, T3)  # 95 ns ± 0.887 ns
```

One can see that isinstance(val, SupportsFloat) is roughly twice as slow as isinstance(val, Real) for a positive match, but can be a lot slower for a negative one. The tuple-based Union check can be much faster, but its speed depends on the order of the members: the argument is checked sequentially against the provided types, so if float is placed last, checking a float takes ~90 ns instead of ~35 ns.
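The order-dependence can be sketched with two assumed orderings of the same tuple; isinstance tests members left to right, so placing the common case first short-circuits earlier:

```python
import timeit

FLOAT_FIRST = (float, int, bool, complex)
FLOAT_LAST = (int, bool, complex, float)

x = 3.14
# Both checks succeed, but FLOAT_LAST must reject three types first.
assert isinstance(x, FLOAT_FIRST) and isinstance(x, FLOAT_LAST)

t_first = timeit.timeit(lambda: isinstance(x, FLOAT_FIRST), number=1_000_000)
t_last = timeit.timeit(lambda: isinstance(x, FLOAT_LAST), number=1_000_000)
print(f"float first: {t_first:.3f}s, float last: {t_last:.3f}s")
```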

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster

Metadata


Labels

actionable · module: distributions (Related to torch.distributions) · module: typing (Related to mypy type annotations) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
