Skip to content

Commit d705083

Browse files
ezyangfacebook-github-bot
authored andcommitted
Refactor dispatcher and native to use Signature structure. (#45990)
Summary: Pull Request resolved: #45990 In #45890 we introduced the concept of a CppSignature, which bundled up all of the information necessary to declare a C++ signature for the cpp API. This PR introduces analogous concepts for dispatcher and native: DispatcherSignature and NativeSignature. The three interfaces are not particularly well coupled right now, but they do have some duck typing coincidences: - defn() which renders the C++ definition "bool f(int x)" - decl() which renders the C++ declaration "bool f(int x = 2)" - type() which renders the C++ function type "bool(int)" Maybe at some point we'll introduce a Protocol, or a supertype. Many other methods (like arguments()) have varying types. These signatures also have some helper methods that forward back to real implementations in the api modules. Something to think about is whether or not we should attempt to reduce boilerplate here or not; I'm not too sure about it yet. The net effect is we get to reduce the number of variables we have to explicitly write out in the codegen, since now these are all bundled together into a signature. Something extra special happens in BackendSelect, where we now dynamically select between dispatcher_sig and native_sig as "how" the backend select is implemented. A little bit of extra cleanup: - Some places where we previously advertised Sequence, we now advertise a more informative Tuple. - defn() may take an optional positional parameter overriding the entire name, or a kwarg-only prefix parameter to just add a prefix to the name. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: smessmer Differential Revision: D24223100 Pulled By: ezyang fbshipit-source-id: f985eced08af4a60ba9641d125d0f260f8cda9eb
1 parent f086032 commit d705083

4 files changed

Lines changed: 109 additions & 39 deletions

File tree

tools/codegen/api/dispatcher.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tools.codegen.local as local
77

88
import itertools
9-
from typing import Sequence, Optional
9+
from typing import Sequence, Optional, Tuple
1010

1111
# This file describes the translation of JIT schema to the dispatcher
1212
# API, the *unboxed* calling convention by which invocations through
@@ -65,14 +65,14 @@ def argument(a: Argument) -> DispatcherArgument:
6565
def name(func: FunctionSchema) -> str:
6666
return cpp.name(func)
6767

68-
def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]:
68+
def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]:
6969
if local.use_c10_dispatcher().dispatcher_uses_new_style():
70-
return list(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)))
70+
return tuple(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)))
7171
else:
72-
return [
72+
return tuple(
7373
DispatcherArgument(type=la.type, name=la.name, argument=la.argument)
7474
for la in native.arguments(func)
75-
]
75+
)
7676

7777
# Given a set of CppArguments in scope, return a sequence of dispatcher
7878
# expressions that translate the cpp API into dispatcher API

tools/codegen/api/native.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from tools.codegen.api.types import TensorOptionsArguments, NativeArgument, ThisArgument
44
import tools.codegen.api.cpp as cpp
55

6-
from typing import Union, Sequence
6+
from typing import Union, Sequence, Tuple
77

88
# This file describes the translation of JIT schema to the native functions API.
99
# This looks a lot like the C++ API (which makes historical sense, because the
@@ -74,5 +74,5 @@ def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> Native
7474
else:
7575
assert_never(a)
7676

77-
def arguments(func: FunctionSchema) -> Sequence[NativeArgument]:
78-
return list(map(argument, cpp.group_arguments(func, method=False)))
77+
def arguments(func: FunctionSchema) -> Tuple[NativeArgument, ...]:
78+
return tuple(map(argument, cpp.group_arguments(func, method=False)))

tools/codegen/api/types.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,11 @@ def decl(self) -> str:
198198

199199
# Render the C++ definition for this signature, not including
200200
# the body (with curly braces)
201-
def defn(self, prefix: str = "") -> str:
201+
def defn(self, name: Optional[str] = None, *, prefix: str = "") -> str:
202202
cpp_args_str = ', '.join(a.str_no_default() for a in self.arguments())
203-
return f"{self._returns_type} {prefix}{cpp.name(self.func)}({cpp_args_str})"
203+
if name is None:
204+
name = prefix + cpp.name(self.func)
205+
return f"{self._returns_type} {name}({cpp_args_str})"
204206

205207
# NB: This constructor knows how to disambiguate defaults when
206208
# faithful is True. Ideally this would live as an external process
@@ -280,6 +282,47 @@ class DispatcherArgument:
280282
def __str__(self) -> str:
281283
return f"{self.type} {self.name}"
282284

285+
@dataclass(frozen=True)
286+
class DispatcherSignature:
287+
# The schema this signature is derived from
288+
func: FunctionSchema
289+
290+
# Note to self: if we ever need to reassemble tensor options, we may need to
291+
# also preserve grouping with DispatcherTensorOptionsArguments. This should
292+
# be an unlikely situation, however, since the general direction we are
293+
# headed is to make native:: take everything in expanded form, so you
294+
# shouldn't need to reassemble
295+
_arguments: Tuple[DispatcherArgument, ...]
296+
_returns_type: str
297+
298+
def arguments(self) -> Tuple[DispatcherArgument, ...]:
299+
return self._arguments
300+
301+
def defn(self, name: Optional[str] = None) -> str:
302+
args_str = ', '.join(map(str, self.arguments()))
303+
if name is None:
304+
name = native.name(self.func)
305+
return f"{self._returns_type} {name}({args_str})"
306+
307+
def exprs(self) -> Sequence[DispatcherExpr]:
308+
return dispatcher.exprs(self.arguments())
309+
310+
# Return the C++ function type, e.g., something like int(bool)
311+
def type(self) -> str:
312+
dispatcher_args_types_str = ', '.join(map(lambda a: a.type, self._arguments))
313+
return f'{self._returns_type} ({dispatcher_args_types_str})'
314+
315+
@staticmethod
316+
def from_schema(func: FunctionSchema) -> 'DispatcherSignature':
317+
arguments = dispatcher.arguments(func)
318+
returns_type = dispatcher.returns_type(func.returns)
319+
320+
return DispatcherSignature(
321+
func=func,
322+
_arguments=arguments,
323+
_returns_type=returns_type,
324+
)
325+
283326
# ------------------------------------------------------------------- #
284327

285328
# native types (NativeFunctions.h)
@@ -320,5 +363,36 @@ def str_with_default(self) -> str:
320363
mb_default = f"={self.default}"
321364
return f"{self.type} {self.name}{mb_default}"
322365

366+
@dataclass(frozen=True)
367+
class NativeSignature:
368+
# The schema this signature is derived from
369+
func: FunctionSchema
370+
371+
_arguments: Tuple[NativeArgument, ...]
372+
_returns_type: str
373+
374+
def defn(self, name: Optional[str] = None) -> str:
375+
args_str = ', '.join(map(str, self.arguments()))
376+
if name is None:
377+
name = dispatcher.name(self.func)
378+
return f"{self._returns_type} {name}({args_str})"
379+
380+
def arguments(self) -> Tuple[NativeArgument, ...]:
381+
return self._arguments
382+
383+
def dispatcher_exprs(self) -> Sequence['DispatcherExpr']:
384+
return dispatcher.nativearguments_exprs(self.arguments())
385+
386+
@staticmethod
387+
def from_schema(func: FunctionSchema) -> 'NativeSignature':
388+
arguments = native.arguments(func)
389+
returns_type = native.returns_type(func.returns)
390+
391+
return NativeSignature(
392+
func=func,
393+
_arguments=arguments,
394+
_returns_type=returns_type,
395+
)
396+
323397
# Functions only, no types
324-
import tools.codegen.api.cpp as cpp
398+
from tools.codegen.api import cpp, dispatcher, native

tools/codegen/gen.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,8 @@ def func(f: NativeFunction) -> Optional[str]:
269269
"""
270270

271271
elif target is Target.REGISTRATION:
272-
assert returns_type == dispatcher.returns_type(f.func.returns)
273-
dispatcher_args = dispatcher.arguments(f.func)
274-
dispatcher_args_types_str = ', '.join(map(lambda a: a.type, dispatcher_args))
272+
dispatcher_sig = DispatcherSignature.from_schema(f.func)
273+
275274
if dispatch is None or dispatch == 'Math' or dispatch == 'DefaultBackend':
276275
type_name = f'TypeDefault::{name}'
277276
else:
@@ -289,7 +288,8 @@ def func(f: NativeFunction) -> Optional[str]:
289288
payload = f"TORCH_FN({type_name})"
290289
elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
291290
payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \
292-
f"{returns_type} ({dispatcher_args_types_str})>(TORCH_FN({type_name}))"
291+
f"{dispatcher_sig.type()}>(TORCH_FN({type_name}))"
292+
293293
else:
294294
assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper
295295
payload = f"torch::CppFunction::makeUnboxedOnly(&{type_name})"
@@ -338,17 +338,17 @@ def go(f: NativeFunction) -> Optional[str]:
338338
assert target is Target.DEFINITION
339339

340340
def generate_defn(sig: CppSignature) -> str:
341+
dispatcher_sig = DispatcherSignature.from_schema(f.func)
342+
341343
dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs())
342-
dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
343-
dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs))
344344
dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs))
345345

346346
return f"""
347347
// aten::{f.func}
348348
{sig.defn()} {{
349349
static auto op = c10::Dispatcher::singleton()
350350
.findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
351-
.typed<{dispatcher_returns_type} ({dispatcher_types_str})>();
351+
.typed<{dispatcher_sig.type()}>();
352352
return op.call({dispatcher_exprs_str});
353353
}}
354354
"""
@@ -388,17 +388,17 @@ def go(f: NativeFunction) -> Optional[str]:
388388
assert target is Target.DEFINITION
389389

390390
def generate_defn(sig: CppSignature) -> str:
391+
dispatcher_sig = DispatcherSignature.from_schema(f.func)
392+
391393
dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs())
392-
dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
393-
dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs))
394394
dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs))
395395

396396
return f"""
397397
// aten::{f.func}
398-
{sig.defn("Tensor::")} const {{
398+
{sig.defn(prefix="Tensor::")} const {{
399399
static auto op = c10::Dispatcher::singleton()
400400
.findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
401-
.typed<{dispatcher_returns_type} ({dispatcher_types_str})>();
401+
.typed<{dispatcher_sig.type()}>();
402402
return op.call({dispatcher_exprs_str});
403403
}}
404404
"""
@@ -455,30 +455,26 @@ def go(f: NativeFunction) -> Optional[str]:
455455
return None
456456

457457
name = native.name(f.func)
458-
native_returns_type = native.returns_type(f.func.returns)
459-
native_args = native.arguments(f.func)
458+
native_sig = NativeSignature.from_schema(f.func)
460459

461-
if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_args):
460+
if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()):
462461
return None
463462

464463
native_tensor_args = [
465-
a for a in native_args
464+
a for a in native_sig.arguments()
466465
if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
467466
]
468467

469-
dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
470-
dispatcher_args = dispatcher.arguments(f.func)
468+
dispatcher_sig = DispatcherSignature.from_schema(f.func)
471469

472-
args: Union[Sequence[DispatcherArgument], Sequence[NativeArgument]]
470+
sig: Union[NativeSignature, DispatcherSignature]
473471
if local.use_c10_dispatcher().dispatcher_uses_new_style():
474-
returns_type = dispatcher_returns_type
475-
args = dispatcher_args
476-
exprs = dispatcher.exprs(dispatcher_args)
472+
sig = dispatcher_sig
473+
dispatcher_exprs = dispatcher_sig.exprs()
477474
dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
478475
else:
479-
returns_type = native_returns_type
480-
args = native_args
481-
exprs = dispatcher.nativearguments_exprs(native_args)
476+
sig = native_sig
477+
dispatcher_exprs = native_sig.dispatcher_exprs()
482478
dispatch_key = "options.computeDispatchKey()"
483479

484480
if target is Target.DEFINITION:
@@ -496,24 +492,24 @@ def go(f: NativeFunction) -> Optional[str]:
496492
compute_dk = f"DispatchKey _dk = {dispatch_key};"
497493
return f"""\
498494
// aten::{f.func}
499-
{returns_type} {name}({', '.join(str(a) for a in args)}) {{
495+
{sig.defn(name)} {{
500496
static auto op = c10::Dispatcher::singleton()
501497
.findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
502-
.typed<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>();
498+
.typed<{dispatcher_sig.type()}>();
503499
{compute_dk}
504500
DispatchKey _autograd_dk = c10::getAutogradKeyFromBackend(_dk);
505501
// This trick allows calling Autograd backend kernel first and then backend kernel,
506502
// without adding another AutogradBackendSelect dispatch key.
507503
DispatchKey _current_dk = at::impl::variable_excluded_from_dispatch() ? _dk : _autograd_dk;
508-
return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in exprs)});
504+
return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in dispatcher_exprs)});
509505
}}
510506
"""
511507
elif target is Target.REGISTRATION:
512508
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
513509
return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
514510
elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
515511
return f"""m.impl("aten::{f.func.name}",
516-
c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>(
512+
c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_sig.type()}>(
517513
TORCH_FN({name})));"""
518514
else:
519515
assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper

0 commit comments

Comments
 (0)