|
6 | 6 | from tools.codegen.context import method_with_native_function |
7 | 7 | from tools.codegen.utils import Target, mapMaybe |
8 | 8 | from tools.codegen.model import (Argument, ExternalBackendFunction, |
9 | | - ExternalBackendFunctionsGroup, |
| 9 | + ExternalBackendFunctionsGroup, SchemaKind, |
10 | 10 | assert_never, Return, is_generic_dispatch_key, |
11 | 11 | ListType, OptionalType, BaseType, BaseTy, Variant) |
12 | 12 | from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup |
|
66 | 66 | 'var', |
67 | 67 | ] |
68 | 68 |
|
| 69 | +# TODO: delete and re-use the one from gen.py |
| 70 | +def has_autogenerated_composite_kernel(f: ExternalBackendFunction) -> bool: |
| 71 | + return (f.native_function.structured or f.native_function.structured_delegate is not None) and \ |
| 72 | + (f.native_function.func.kind() == SchemaKind.functional or f.native_function.func.kind() == SchemaKind.inplace) |
| 73 | + |
69 | 74 | def requires_backend_wrapper(f: ExternalBackendFunction) -> bool: |
70 | | - requires_lowering = not any(is_generic_dispatch_key(k) for k in f.native_function.dispatch) |
| 75 | + requires_lowering = not any(is_generic_dispatch_key(k) for k in f.native_function.dispatch) \ |
| 76 | + and not has_autogenerated_composite_kernel(f) |
71 | 77 | has_xla_lowering = f.metadata is not None |
72 | 78 | in_denylist = any([re.match(frx, str(f.native_function.func.name)) for frx in _FN_DENYLIST_REGEX]) |
73 | 79 | return not in_denylist and (requires_lowering or has_xla_lowering) |
|
0 commit comments