
Replicate old generated_jit behavior on the basis of overload for jit-only variants of functions #8897

@shaunc

Description


Feature request

The `@overload` functionality that is meant to replace `@generated_jit` focuses on existing Python functions for which we want versions callable from jitted code. This is suboptimal when we want to provide type-specific variants of a function that should also run jitted when called from Python.

In this case, we need to provide a Python "stub" for the function, which is what Python code calls by default. If we want Python code to call the njitted version instead, the workaround is to give the stub a private name and then define an additional njitted wrapper function that prima facie calls the private stub; numba then replaces that call with the overload.

Cf. https://numba.discourse.group/t/replacing-generated-jit-with-overload/1890/3?u=shaunc

This pattern generates a lot of cruft and can make such functions harder to maintain. I wrote the following replacement for `generated_jit`, which seems to work in a simple test. Could this (or a more sophisticated version) serve as a replacement for `@generated_jit`?

```python
from functools import wraps
from typing import Any, Callable, Optional, TypeVar

import numba as nb  # type: ignore
from numba import extending as nbe  # type: ignore

T = TypeVar("T")


def generated_jit(
    jit_options: Optional[dict[str, Any]] = None,
    strict: bool = True,
    inline: str = "never",
    prefer_literal: bool = False,
    cache: bool = False,
    **kw: Any,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator to provide variant JIT-compiled functions for argument types.

    Uses numba `@overload` to simulate the previous `@generated_jit` behavior.
    `@overload` is focused on providing jitted compilations of Python functions.
    This decorator allows us to use `@overload` to provide jitted functions
    where the call from Python should also call a jitted function.
    """

    def _decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Placeholder target for @overload; inside jitted code numba replaces
        # calls to it with the implementation that `func` returns.
        def _stub(*args: Any) -> T:
            raise NotImplementedError("should not be called")

        # Copy instead of mutating: a mutable default (or the caller's dict)
        # would otherwise accumulate settings across decorated functions.
        options = dict(jit_options or {})
        options.setdefault("cache", cache)
        options.setdefault("inline", inline)

        # `func` acts as the overload's typing function: it receives numba
        # types and returns the implementation to compile for them.
        nbe.overload(
            wraps(func)(_stub),
            jit_options=options,
            strict=strict,
            inline=inline,
            prefer_literal=prefer_literal,
            **kw,
        )(func)

        # Jitted entry point, so that calls from Python are compiled too.
        @nb.njit(**options)  # type: ignore
        def _jit(*args: Any) -> T:
            return _stub(*args)

        return wraps(func)(_jit)  # type: ignore

    return _decorator
```

Labels: needs triage/maintainer discussion, question, stale