Commit 114aa3c

Update on "Making ops c10 full: optional out arguments"
We have some (but very few) ops that take optional out arguments (`Tensor(a!)? out`). This PR makes them mandatory arguments and enables c10-fullness for them. Only a very small number of ops is affected. Putting this up for discussion.

Alternatives considered: if we keep them optional, we run into lots of issues in the dispatcher, because we have to decide what the dispatcher calling convention for this argument type should be.

1) If we keep passing them in as `Tensor&` arguments and returning them as `tuple<Tensor&, Tensor&, Tensor&>`, i.e. basically the same as now, then the schema inference check will complain: "Your kernel function got inferred to have a `Tensor` argument but your native_functions.yaml declaration says `Tensor?`. This is a mismatch, you made an error." We could disable that check, but that would open the door for real mistakes to go unreported in the future. This sounds bad.

2) If we change them to a type that schema inference can differentiate from `Tensor`, say pass them in as `const optional<Tensor>&` and return them as `tuple<const optional<Tensor>&, const optional<Tensor>&, const optional<Tensor>&>`, then our boxing logic fails because it can no longer recognize those as out overloads and shortcut the return value as it does right now. We might be able to rewrite the boxing logic, but that could be difficult and could easily become a rabbit hole of cleaning up `Tensor&` references throughout the system.

Furthermore, optional out arguments don't really make sense in C++: the C++ API puts them at the front of the argument list, so you can't omit them anyway when calling an op. You would be able to omit them when calling from Python with out kwargs, but it's not clear we want that discrepancy between the C++ and Python APIs.

Differential Revision: [D25422197](https://our.internmc.facebook.com/intern/diff/D25422197/)

[ghstack-poisoned]
2 parents 9950752 + 6b35ba5 commit 114aa3c

10 files changed: 72 additions & 68 deletions

c10/util/TypeCast.h

Lines changed: 2 additions & 0 deletions
@@ -170,3 +170,5 @@ To checked_convert(From f, const char* name) {
 }

 } // namespace c10
+
+// Trigger tests for D25440771. TODO: Remove this line any time you want.

c10/util/complex.h

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ struct alignas(sizeof(T) * 2) complex {
     return real() || imag();
   }

-  constexpr T real() const {
+  C10_HOST_DEVICE constexpr T real() const {
     return real_;
   }
   constexpr void real(T value) {

docs/source/fx.rst

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+.. currentmodule:: torch.fx
+
+torch.fx
+=============
+
+Overview
+--------
+.. automodule:: torch.fx
+
+
+API Reference
+-------------
+
+.. autofunction:: torch.fx.symbolic_trace
+
+.. autoclass:: torch.fx.GraphModule
+   :members:
+
+   .. automethod:: __init__
+
+.. autoclass:: torch.fx.Graph
+   :members:
+
+   .. automethod:: __init__
+
+.. autoclass:: torch.fx.Node
+   :members:
+
+.. autoclass:: torch.fx.Tracer
+   :members:

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ Features described in this documentation are classified by release status:
    torch.distributions <distributions>
    torch.fft <fft>
    futures
+   fx
    torch.hub <hub>
    torch.jit <jit>
    torch.linalg <linalg>

test/run_test.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
     'distributed/test_distributed_spawn',
     'distributions/test_constraints',
     'distributions/test_distributions',
+    'test_dispatch',
     'test_expecttest',
     'test_foreach',
     'test_indexing',

test/test_dispatch.py

Lines changed: 26 additions & 49 deletions
@@ -3,7 +3,6 @@

 from collections import namedtuple
 import itertools
-import unittest
 import re

 # TODO: Expand the dispatcher API to be a generic API for interfacing with
@@ -256,7 +255,14 @@ def test_def_impl_schema_mismatch(self):
             # m.impl("foo", [](const Tensor & x) { return x })
             lambda m: m.impl_t_t("foo"),
         ], expect_raises=True).state
-        self.assertExpectedInline(state, '''In registration for test::foo: expected schema of operator to be "test::foo(Tensor x, Tensor y) -> (Tensor)" (registered at /dev/null:0), but got inferred schema "(Tensor _0) -> (Tensor _0)" (impl_t_t). The number of arguments is different. 2 vs 1.''')  # noqa
+        self.assertExpectedInline(state, '''\
+Inferred operator schema for a C++ kernel function doesn't match the expected function schema.
+  operator: test::foo
+  expected schema: test::foo(Tensor x, Tensor y) -> (Tensor)
+    registered at /dev/null:0
+  inferred schema: (Tensor _0) -> (Tensor _0)
+    impl_t_t
+  reason: The number of arguments is different. 2 vs 1.''')  # noqa

     def test_def_with_inference(self):
         state = self.commute("foo", [
@@ -656,17 +662,19 @@ def test_computed_table_with_cpu_autograd_math_defaultbackend(self):
AutogradXLA: fn_autograd [autograd kernel]
''')

-    # Can't do this yet for BC reasons
-    @unittest.expectedFailure
     def test_multiple_def_error(self):
-        state = self.commute("foo", [
+        ops = [
             # m.def("foo(Tensor x, Tensor y) -> Tensor")
             lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
             # m.def("foo(Tensor x, Tensor y) -> Tensor")
             lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
-        ], expect_raises=True).state
-        # TODO: fill in the error message here
-        # self.assertExpectedInline(state, '''''')
+        ]
+        self.assertExpectedInline(
+            self.commute("foo", ops, expect_raises=True).state,
+            '''Tried to register an operator (test::foo(Tensor x, Tensor y) -> (Tensor)) with the same name and overload '''
+            '''name multiple times. Each overload's schema should only be registered with a single call to def(). '''
+            '''Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0'''
+        )

     def test_def_with_explicit_alias(self):
         state = self.commute("foo", [
@@ -683,52 +691,22 @@ def test_def_with_explicit_alias(self):
alias analysis kind: PURE_FUNCTION
''')

-    # TODO: get rid of this test when multiple defs are wrong
-    def test_multiple_def_schema_mismatch(self):
-        # error message is order dependent
-        ops = [
-            # m.def("foo(Tensor x, Tensor y) -> Tensor")
-            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
-            # m.def("foo(Tensor x) -> Tensor")
-            lambda m: m.def_("foo(Tensor x) -> Tensor"),
-        ]
-        self.assertExpectedInline(
-            self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True).state,
-            '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0)'''  # noqa
-        )
-        self.assertExpectedInline(
-            self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True).state,
-            '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0)'''  # noqa
-        )
-
     def test_multiple_def_alias_defaulting(self):
-        # TODO: should be an error in both directions soon
         ops = [
             # m.def(torch::schema("foo(Tensor x) -> Tensor",
             #                     c10::AliasAnalysisKind::PURE_FUNCTION))
             lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
             # RegisterOperators().op("foo(Tensor x) -> Tensor")
             lambda m: m.def_legacy("foo(Tensor x) -> Tensor"),
         ]
-        state = self.commute("foo", ops, ctor_order=(0, 1)).state
         self.assertExpectedInline(
-            state,
-            '''\
-name: test::foo
-schema: test::foo(Tensor x) -> (Tensor)
-debug: registered at /dev/null:0
-alias analysis kind: PURE_FUNCTION
-'''
+            self.commute("foo", ops, expect_raises=True).state,
+            '''Tried to register an operator (test::foo(Tensor x) -> (Tensor)) with the same name and overload '''
+            '''name multiple times. Each overload's schema should only be registered with a single call to def(). '''
+            '''Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0'''
         )
-        # NB: When run with ctor order (1, 0), the destructors are NOT
-        # COMMUTATIVE. THIS IS A BUG, however we are purposely leaving the bug
-        # in as it is very benign (only leaves us in a bad state during
-        # destruction, when no useful work is being done), will be fixed when we
-        # make alias defaulting a hard error, and is very nontrivial to fix
-        # prior to that.

     def test_multiple_def_alias_mismatch(self):
-        # error message is order dependent
         ops = [
             # m.def(torch::schema("foo(Tensor x) -> Tensor",
             #                     c10::AliasAnalysisKind::PURE_FUNCTION))
@@ -738,12 +716,10 @@ def test_multiple_def_alias_mismatch(self):
             lambda m: m.def_("foo(Tensor x) -> Tensor", alias="CONSERVATIVE"),
         ]
         self.assertExpectedInline(
-            self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True).state,
-            '''Tried to define the schema for test::foo with different alias analysis kinds: PURE_FUNCTION (registered at /dev/null:0) vs CONSERVATIVE (registered at /dev/null:0)'''  # noqa
-        )
-        self.assertExpectedInline(
-            self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True).state,
-            '''Tried to define the schema for test::foo with different alias analysis kinds: CONSERVATIVE (registered at /dev/null:0) vs PURE_FUNCTION (registered at /dev/null:0)'''  # noqa
+            self.commute("foo", ops, expect_raises=True).state,
+            '''Tried to register an operator (test::foo(Tensor x) -> (Tensor)) with the same name and overload '''
+            '''name multiple times. Each overload's schema should only be registered with a single call to def(). '''
+            '''Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0'''  # noqa
         )

     def test_multiple_fallback(self):
@@ -754,7 +730,8 @@ def test_multiple_fallback(self):
         except RuntimeError as e:
             self.assertExpectedInline(
                 str(e),
-                '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0'''  # noqa
+                '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration '''
+                '''registered at /dev/null:0, new registration registered at /dev/null:0'''  # noqa
             )
         else:
             self.assertTrue(False)
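A side note on the two string-literal patterns the rewritten expectations rely on (message text below is hypothetical): a backslash right after the opening `'''` suppresses the leading newline in a multiline expected string, and adjacent string literals are concatenated by the parser, which lets a long one-line error message stay under flake8's line limit without embedding newlines.

```python
# Pattern 1: '''\ keeps the first expected line flush-left (no leading "\n").
multiline = '''\
operator: test::foo
reason: The number of arguments is different. 2 vs 1.'''

# Pattern 2: adjacent literals concatenate; no newline is introduced between
# them, so the expected value remains a single long line.
one_line = (
    '''Tried to register an operator with the same name and overload '''
    '''name multiple times.'''
)

assert multiline.startswith("operator:")  # no leading newline thanks to \
assert "\n" not in one_line               # concatenation adds no newline
```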

third_party/tensorpipe

Submodule tensorpipe updated 154 files

third_party/tensorpipe.BUILD

Lines changed: 1 addition & 10 deletions
@@ -93,13 +93,7 @@ TENSORPIPE_HEADERS = glob([
 TENSORPIPE_BASE_SRCS = glob([
     "tensorpipe/*.cc",
     "tensorpipe/channel/*.cc",
-    "tensorpipe/common/address.cc",
-    "tensorpipe/common/epoll_loop.cc",
-    "tensorpipe/common/error.cc",
-    "tensorpipe/common/fd.cc",
-    "tensorpipe/common/ibv.cc",
-    "tensorpipe/common/socket.cc",
-    "tensorpipe/common/system.cc",
+    "tensorpipe/common/*.cc",
     "tensorpipe/core/*.cc",
     "tensorpipe/transport/*.cc",
     "tensorpipe/util/*/*.cc",
@@ -113,10 +107,7 @@ TENSORPIPE_SRCS = TENSORPIPE_BASE_SRCS + glob([
 ])

 TENSORPIPE_SRCS_CUDA = TENSORPIPE_SRCS + glob([
-    "tensorpipe/common/cuda_loop.cc",
-    "tensorpipe/channel/cuda_basic/*.cc",
     "tensorpipe/channel/cuda_ipc/*.cc",
-    "tensorpipe/channel/cuda_xth/*.cc",
 ])

 cc_library(
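The BUILD change above collapses seven explicit `tensorpipe/common/*.cc` entries into one wildcard. A quick Python check that the wildcard covers every file from the old list (Bazel's `glob` has its own semantics — package-relative paths, `exclude` lists — so this only illustrates the pattern matching, not Bazel itself):

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Recreate the directory layout with the file names from the old list.
    common = os.path.join(root, "tensorpipe", "common")
    os.makedirs(common)
    old_list = ["address.cc", "epoll_loop.cc", "error.cc", "fd.cc",
                "ibv.cc", "socket.cc", "system.cc"]
    for name in old_list:
        open(os.path.join(common, name), "w").close()

    # The single wildcard picks up every entry the old explicit list named.
    matched = sorted(os.path.basename(p)
                     for p in glob.glob(os.path.join(common, "*.cc")))
    assert matched == sorted(old_list)
```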

torch/fx/graph.py

Lines changed: 4 additions & 4 deletions
@@ -415,7 +415,7 @@ def call_method(self,
                 type_expr: Optional[Any] = None) -> Node:
         """
         Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node
-        represents a call to a given method on the 0th element of `args.
+        represents a call to a given method on the 0th element of ``args``.

         Args:
@@ -756,9 +756,9 @@ def lint(self, root : Optional[torch.nn.Module] = None):
         """
         Runs various checks on this Graph to make sure it is well-formed. In
         particular:
-        - Checks Nodes have correct ownership (owned by this graph)
-        - Checks Nodes appear in topological order
-        - If ``root`` is provided, checks that targets exist in ``root``
+            - Checks Nodes have correct ownership (owned by this graph)
+            - Checks Nodes appear in topological order
+            - If ``root`` is provided, checks that targets exist in ``root``

         Args:
torch/fx/symbolic_trace.py

Lines changed: 5 additions & 3 deletions
@@ -66,6 +66,7 @@ def create_arg(self, a: Any) -> 'Argument':
         * For a Parameter, emit a ``get_attr`` node referring to that Parameter
         * For a non-Parameter Tensor, store the Tensor away in a special
           attribute referring to that attribute.
+
         This method can be overridden to support more types.

         Args:
@@ -126,6 +127,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo
             via this parameter.

         Args:
+
             m (Module): The module being queried about
             module_qualified_name (str): The path to root of this module. For example,
                 if you have a module hierarchy where submodule ``foo`` contains
@@ -163,7 +165,7 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tu
         This method can be overridden to--for example--create nested traced
         GraphModules, or any other behavior you would want while tracing across
         ``Module`` boundaries.
-            ``Module`` boundaries.
+        ``Module`` boundaries.

         Args:
@@ -186,8 +188,8 @@
     def create_args_for_root(self, root_fn, is_module):
         """
         Create ``placeholder`` nodes corresponding to the signature of the ``root``
-        Module. This method introspects ``root``'s signature and emits those
-        nodes accordingly, also supporting *args and **kwargs.
+        Module. This method introspects root's signature and emits those
+        nodes accordingly, also supporting ``*args`` and ``**kwargs``.
         """
         # In some cases, a function or method has been decorated with a wrapper
         # defined via ``functools.wraps``. In this case, the outer code object