Feature request
The `@overload` functionality that is meant to replace `@generated_jit` focuses on existing Python functions for which we want versions callable from within jitted code. This is suboptimal when we want to provide type-specialized variants of a function that should also be jitted when called from Python.
In this case, we need to provide a Python "stub" for the function, and by default that stub is what Python code calls. If we want Python code to call the njitted version, the workaround is to give the stub a private name, then define an additional njitted wrapper function that prima facie calls the private stub; inside the jitted wrapper, that call is then replaced by the overload.
cf https://numba.discourse.group/t/replacing-generated-jit-with-overload/1890/3?u=shaunc
This pattern generates a lot of cruft and may make such functions harder to maintain. I wrote the following replacement for `@generated_jit`, which seems to work in a simple test. Can this (or a more sophisticated version) be used as a replacement for `@generated_jit`?
```python
from functools import wraps
from typing import Any, Callable, Optional, TypeVar

import numba as nb  # type: ignore
from numba import extending as nbe  # type: ignore

T = TypeVar("T")


def generated_jit(
    jit_options: Optional[dict[str, Any]] = None,
    strict: bool = True,
    inline: str = "never",
    prefer_literal: bool = False,
    cache: bool = False,
    **kw: Any,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator to provide variant JIT-compiled functions for argument types.

    Uses numba `@overload` to simulate the previous `@generated_jit` behavior.
    `@overload` is focused on providing jitted compilations of Python functions.
    This decorator lets us use `@overload` to provide jitted functions where
    the call from Python should also go to a jitted function.
    """
    # Copy the options so that neither the caller's dict nor a shared mutable
    # default is modified by setdefault() across decorator invocations.
    options: dict[str, Any] = dict(jit_options or {})
    options.setdefault("cache", cache)
    options.setdefault("inline", inline)

    def _decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Python-level stub: @overload needs a function object to attach the
        # typing-time implementations (`func`) to. It is never run directly.
        def _stub(*args: Any) -> T:
            raise NotImplementedError("should not be called")

        # Register `func` as the overload of the stub; `func` receives numba
        # types and returns the implementation to compile.
        nbe.overload(
            wraps(func)(_stub),
            jit_options=options,
            strict=strict,
            inline=inline,
            prefer_literal=prefer_literal,
            **kw,
        )(func)

        # Jitted entry point callable from Python; in nopython mode the call
        # to `_stub` is resolved to the overload registered above.
        @nb.njit(**options)  # type: ignore
        def _jit(*args: Any) -> T:
            return _stub(*args)

        return wraps(func)(_jit)  # type: ignore

    return _decorator
```