It appears that when the type of an object has `Py_TPFLAGS_HAVE_VECTORCALL` set in `tp_flags`, overrides of `__call__` made from within Python are ignored.
- `PyObject_Call` checks whether vectorcall is available (which is determined only by the type), and if so, performs the vectorcall: https://github.com/python/cpython/blob/3.11/Objects/call.c#L328
- The vectorcall path does not check `tp_call` (is this where an overridden `__call__` ends up?): https://github.com/python/cpython/blob/3.11/Include/internal/pycore_call.h#L39
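The dispatch order described above can be modeled in a few lines of Python. This is a hypothetical toy model, not CPython's code: `ToyType`, `ToyObject`, `toy_call`, and `fast_impl` are illustrative names. It assumes, as the report suggests, that assigning `__call__` replaces only the `tp_call` slot, while the vectorcall flag and the per-object vectorcall pointer are left unchanged.

```python
from typing import Any, Callable, Optional

# Same bit as Py_TPFLAGS_HAVE_VECTORCALL in CPython's Include/object.h.
HAVE_VECTORCALL = 1 << 11


class ToyType:
    """Hypothetical model of the call-related parts of a C-level type object."""

    def __init__(self, flags: int, tp_call: Callable[..., Any]) -> None:
        self.tp_flags = flags
        self.tp_call = tp_call


class ToyObject:
    """Hypothetical instance that may carry its own vectorcall function pointer."""

    def __init__(self, tp: ToyType, vectorcall: Optional[Callable[..., Any]] = None) -> None:
        self.ob_type = tp
        self.vectorcall = vectorcall


def toy_call(obj: ToyObject, *args: Any) -> Any:
    """Model the order described above: vectorcall first, tp_call only as a fallback."""
    tp = obj.ob_type
    if tp.tp_flags & HAVE_VECTORCALL and obj.vectorcall is not None:
        # The stored vectorcall pointer is used; tp_call is never consulted.
        return obj.vectorcall(*args)
    return tp.tp_call(*args)


def fast_impl(x: int) -> int:
    return x + 1


compiled_type = ToyType(HAVE_VECTORCALL, tp_call=fast_impl)
g = ToyObject(compiled_type, vectorcall=fast_impl)
print(toy_call(g, 5))  # 6

# Model of `type(g).__call__ = lambda *args: 0`: only the tp_call slot changes.
compiled_type.tp_call = lambda *args: 0
print(toy_call(g, 5))  # still 6: the flag and the vectorcall pointer are untouched
```

In this model the override only takes effect for a type without the flag, which matches the behavior seen in the JAX example.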
I encountered this bug when using the latest version of JAX, which added vectorcall support in tensorflow/tensorflow@bf3eb11:
```python
import jax
g = jax.jit(lambda x: x + 1)
print(type(g))
# > <class 'google3.third_party.tensorflow.compiler.xla.python.xla_extension.CompiledFunction'>
print(g(5))
# > 6
type(g).__call__ = lambda *args: 0
print(g(5))
# > 6
```
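For comparison, a class defined in Python does not set the vectorcall flag, and reassigning its `__call__` takes effect as expected. This sketch assumes the bit value of `Py_TPFLAGS_HAVE_VECTORCALL` (`1 << 11`, as in CPython's `Include/object.h`); `has_vectorcall_flag` and `PlainCallable` are illustrative names, not part of any API. The same helper can be applied to `type(g)` from the snippet above to check whether `CompiledFunction` sets the flag.

```python
import types

# Py_TPFLAGS_HAVE_VECTORCALL is defined as (1UL << 11) in CPython's Include/object.h.
PY_TPFLAGS_HAVE_VECTORCALL = 1 << 11


def has_vectorcall_flag(tp: type) -> bool:
    """Return True if the type object advertises vectorcall support in tp_flags."""
    return bool(tp.__flags__ & PY_TPFLAGS_HAVE_VECTORCALL)


class PlainCallable:
    """A Python-defined class: calls go through tp_call, which looks up __call__."""

    def __call__(self, x):
        return x + 1


p = PlainCallable()
print(has_vectorcall_flag(PlainCallable))  # False
print(p(5))                                # 6
PlainCallable.__call__ = lambda *args: 0
print(p(5))                                # 0: the override is honored

# Plain Python functions do use vectorcall (their type is immutable, so it cannot be patched).
print(has_vectorcall_flag(types.FunctionType))  # True
```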