@@ -42,21 +42,18 @@ using supported_primitive_arg_types = guts::typelist::typelist<
4242 at::Tensor,
4343 at::Scalar,
4444 c10::QScheme,
45- c10::ScalarType>;
46-
47- template <class T , bool AllowDeprecatedTypes, class Enable = void >
48- struct assert_is_valid_input_type {
49- assert_is_valid_input_type () {
50- auto tmap = c10::getCustomClassTypeMap ();
51- TORCH_CHECK (
52- c10::isCustomClassRegistered<T>(),
53- " Tried to use undefined class as input argument" );
54- }
55- };
56-
57- template <class T , bool AllowDeprecatedTypes>
58- struct assert_is_valid_input_type <T, AllowDeprecatedTypes, std::enable_if_t <guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
59- // everything is ok, this is a primitive type
45+ c10::ScalarType
46+ >;
47+
48+ template <class T , bool AllowDeprecatedTypes, class Enable = void > struct assert_is_valid_input_type {
49+ assert_is_valid_input_type () {
50+ guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
51+ /* everything is ok, this is a primitive type */
52+ }, /* else */ [] {
53+ auto tmap = c10::getCustomClassTypeMap ();
54+ TORCH_CHECK (c10::isCustomClassRegistered<T>(), " Tried to use undefined class as input argument" );
55+ });
56+ }
6057 };
6158
6259 template <class T , bool AllowDeprecatedTypes>
@@ -139,16 +136,15 @@ struct assert_is_valid_input_type {
139136
140137 template <class T , bool AllowDeprecatedTypes, class Enable = void > struct assert_is_valid_output_type {
141138 assert_is_valid_output_type () {
142- auto tmap = getCustomClassTypeMap ();
143- TORCH_CHECK (c10::isCustomClassRegistered<T>(), " Tried to use undefined class as output" );
139+ guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
140+ /* everything is ok, this is a primitive type */
141+ }, /* else */ [] {
142+ auto tmap = getCustomClassTypeMap ();
143+ TORCH_CHECK (c10::isCustomClassRegistered<T>(), " Tried to use undefined class as output" );
144+ });
144145 }
145146 };
146147
147- template <class T , bool AllowDeprecatedTypes>
148- struct assert_is_valid_output_type <T, AllowDeprecatedTypes, std::enable_if_t <guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
149- // everything is ok, this is a primitive type
150- };
151-
152148 template <class T , bool AllowDeprecatedTypes>
153149 struct assert_is_valid_output_type <c10::optional<T>, AllowDeprecatedTypes>
154150 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
@@ -268,33 +264,30 @@ struct assert_is_valid_input_type {
268264 torch::jit::push (*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>(std::move (std::get<indices>(output)))...);
269265 }
270266 };
271-
272- template <class KernelFunctor , bool AllowDeprecatedTypes, class Enable = void > struct make_boxed_from_unboxed_functor final {};
273-
274- // SFINAE version for kernels that return an output
275- template <class KernelFunctor , bool AllowDeprecatedTypes>
276- struct make_boxed_from_unboxed_functor <KernelFunctor, AllowDeprecatedTypes, std::enable_if_t <!std::is_same<void , typename guts::infer_function_traits_t <KernelFunctor>::return_type>::value>> final {
277- static_assert (std::is_base_of<OperatorKernel, KernelFunctor>::value, " Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it." );
278-
279- static void call (OperatorKernel* functor, const OperatorHandle&, Stack* stack) {
280- constexpr size_t num_inputs = guts::infer_function_traits_t <KernelFunctor>::number_of_parameters;
281- KernelFunctor* functor_ = static_cast <KernelFunctor*>(functor);
282- auto output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, stack);
283- torch::jit::drop (*stack, num_inputs);
284- push_outputs<typename guts::infer_function_traits_t <KernelFunctor>::return_type, AllowDeprecatedTypes>::call (std::move (output), stack);
267+ template <bool AllowDeprecatedTypes>
268+ struct push_outputs <void , AllowDeprecatedTypes> final {
269+ static void call (int /* dummy*/ , Stack* /* stack*/ ) {
285270 }
286271 };
287272
288- // SFINAE version for kernels that don't return an output
289273 template <class KernelFunctor , bool AllowDeprecatedTypes>
290- struct make_boxed_from_unboxed_functor <KernelFunctor, AllowDeprecatedTypes, std:: enable_if_t <std::is_same< void , typename guts:: infer_function_traits_t <KernelFunctor>::return_type>::value>> final {
274+ struct make_boxed_from_unboxed_functor final {
291275 static_assert (std::is_base_of<OperatorKernel, KernelFunctor>::value, " Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it." );
292276
293277 static void call (OperatorKernel* functor, const OperatorHandle&, Stack* stack) {
294278 constexpr size_t num_inputs = guts::infer_function_traits_t <KernelFunctor>::number_of_parameters;
295279 KernelFunctor* functor_ = static_cast <KernelFunctor*>(functor);
296- call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, stack);
297- torch::jit::pop (*stack, num_inputs);
280+
281+ using ReturnType = typename guts::infer_function_traits_t <KernelFunctor>::return_type;
282+ constexpr bool has_outputs = !std::is_same<void , ReturnType>::value;
283+ guts::if_constexpr<has_outputs>([&] (auto _) {
284+ auto output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, _ (stack));
285+ torch::jit::drop (*stack, num_inputs);
286+ push_outputs<ReturnType, AllowDeprecatedTypes>::call (std::move (output), stack);
287+ }, /* else */ [&] {
288+ call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, stack);
289+ torch::jit::drop (*stack, num_inputs);
290+ });
298291 }
299292 };
300293
0 commit comments