
Cannot patch __call__ on an object with vectorcall defined #101497

@davidsaxton

It appears that when the type of an object has Py_TPFLAGS_HAVE_VECTORCALL set in tp_flags, overrides of __call__ made from Python are ignored: calls keep going through the C vectorcall implementation instead of the patched __call__.
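For context, CPython exposes a type's tp_flags through the read-only, implementation-specific __flags__ attribute, so you can check from Python whether a callable's type advertises vectorcall. The sketch below is a minimal illustration. It assumes the bit value 1 << 11 for Py_TPFLAGS_HAVE_VECTORCALL, taken from Include/object.h in CPython 3.8+. The names has_vectorcall_flag, plain_function, and PlainCallable are hypothetical and used only for illustration.

```python
# Assumption: bit value of Py_TPFLAGS_HAVE_VECTORCALL in CPython 3.8+ (Include/object.h).
# It is not exposed as a Python-level constant, so it is mirrored here by hand.
PY_TPFLAGS_HAVE_VECTORCALL = 1 << 11


def has_vectorcall_flag(obj) -> bool:
    """Return True if type(obj) sets Py_TPFLAGS_HAVE_VECTORCALL.

    Relies on type.__flags__, a CPython implementation detail mirroring tp_flags.
    """
    return bool(type(obj).__flags__ & PY_TPFLAGS_HAVE_VECTORCALL)


def plain_function(x):
    return x + 1


class PlainCallable:
    def __call__(self, x):
        return x + 1


print(has_vectorcall_flag(plain_function))   # True: the function type sets the flag
print(has_vectorcall_flag(PlainCallable()))  # False: class statements do not set it
```

Only the flag on type(obj) matters for calling an instance. The flag on the metaclass, which governs calling the class itself, is a separate question.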

I hit this bug with the latest version of JAX, which added vectorcall support in tensorflow/tensorflow@bf3eb11:

import jax
g = jax.jit(lambda x: x + 1)
print(type(g))
# > <class 'google3.third_party.tensorflow.compiler.xla.python.xla_extension.CompiledFunction'>
print(g(5))
# > 6
type(g).__call__ = lambda *args: 0
print(g(5))
# > 6  (expected 0: the patched __call__ is not used)
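For comparison, the same patch works on an ordinary Python class, and this is the behavior the report expects. The class Adder below is a hypothetical stand-in for the JAX object. Its type does not set Py_TPFLAGS_HAVE_VECTORCALL, so calls go through tp_call, which looks up __call__ on the type each time the instance is called.

```python
class Adder:
    """A plain Python callable; its type does not set Py_TPFLAGS_HAVE_VECTORCALL."""

    def __call__(self, x):
        return x + 1


g = Adder()
before = g(5)

# Reassigning __call__ on the class affects existing instances, because the
# slot used for calls looks up __call__ on the type at call time.
type(g).__call__ = lambda *args: 0
after = g(5)

print(before, after)  # 6 0
```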
