You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update on "Making ops c10 full: optional out arguments"
We have some (but very few) ops that take optional out arguments `Tensor(a!)? out`.
This PR makes them non-optional mandatory arguments and enables c10-fullness for them.
There is only a very small number of ops affected by this.
Putting this up for discussion.
Alternatives considered:
If we keep them optional, we run into lots of issues in the dispatcher. We have to decide what the dispatcher calling convention for this argument type should be.
1) If we keep passing them in as `Tensor&` arguments and return them as `tuple<Tensor&, Tensor&, Tensor&>`, so basically same as currently, then the schema inference check will say "Your kernel function got inferred to have a `Tensor` argument but your native_functions.yaml declaration says `Tensor?`. This is a mismatch, you made an error". We could potentially disable that check, but that would open the door for real mistakes to not be reported anymore in the future. This sounds bad.
2) If we change them to a type that schema inference could differentiate from `Tensor`, say we pass them in as `const optional<Tensor>&` and return them as `tuple<const optional<Tensor>&, const optional<Tensor>&, const optional<Tensor>&>`, then our boxing logic fails because it can't recognize those as out overloads anymore and shortcut the return value as it is doing right now. We might be able to rewrite the boxing logic, but that could be difficult and could easily develop into a rabbit hole of having to clean up `Tensor&` references throughout the system where we use them.
Furthermore, having optional out arguments in C++ doesn't really make sense. the C++ API puts them to the front of the argument list, so you can't omit them anyways when calling an op.
You would be able to omit them when calling from Python with out kwargs, but not sure if we want that discrepancy between the c++ and python API.
Differential Revision: [D25422197](https://our.internmc.facebook.com/intern/diff/D25422197/)
[ghstack-poisoned]
# m.impl("foo", [](const Tensor & x) { return x })
257
256
lambdam: m.impl_t_t("foo"),
258
257
], expect_raises=True).state
259
-
self.assertExpectedInline(state, '''In registration for test::foo: expected schema of operator to be "test::foo(Tensor x, Tensor y) -> (Tensor)" (registered at /dev/null:0), but got inferred schema "(Tensor _0) -> (Tensor _0)" (impl_t_t). The number of arguments is different. 2 vs 1.''') # noqa
258
+
self.assertExpectedInline(state, '''\
259
+
Inferred operator schema for a C++ kernel function doesn't match the expected function schema.
'''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0)'''# noqa
'''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0)'''# noqa
702
-
)
703
-
704
694
deftest_multiple_def_alias_defaulting(self):
705
-
# TODO: should be an error in both directions soon
'''Tried to define the schema for test::foo with different alias analysis kinds: PURE_FUNCTION (registered at /dev/null:0) vs CONSERVATIVE (registered at /dev/null:0)'''# noqa
'''Tried to define the schema for test::foo with different alias analysis kinds: CONSERVATIVE (registered at /dev/null:0) vs PURE_FUNCTION (registered at /dev/null:0)'''# noqa
'''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0'''# noqa
733
+
'''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration '''
734
+
'''registered at /dev/null:0, new registration registered at /dev/null:0'''# noqa
0 commit comments