@@ -580,6 +580,38 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
580580
581581#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 ( \
582582 SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
583+ [&] { \
584+ switch (TYPE) { \
585+ AT_PRIVATE_CASE_TYPE (at::ScalarType::Byte, uint8_t , __VA_ARGS__) \
586+ AT_PRIVATE_CASE_TYPE (at::ScalarType::Char, int8_t , __VA_ARGS__) \
587+ AT_PRIVATE_CASE_TYPE (at::ScalarType::Double, double , __VA_ARGS__) \
588+ AT_PRIVATE_CASE_TYPE (at::ScalarType::Float, float , __VA_ARGS__) \
589+ AT_PRIVATE_CASE_TYPE (at::ScalarType::Int, int32_t , __VA_ARGS__) \
590+ AT_PRIVATE_CASE_TYPE (at::ScalarType::Long, int64_t , __VA_ARGS__) \
591+ AT_PRIVATE_CASE_TYPE (at::ScalarType::Short, int16_t , __VA_ARGS__) \
592+ AT_PRIVATE_CASE_TYPE ( \
593+ at::ScalarType::ComplexFloat, std::complex <float >, __VA_ARGS__) \
594+ AT_PRIVATE_CASE_TYPE ( \
595+ at::ScalarType::ComplexDouble, std::complex <double >, __VA_ARGS__) \
596+ AT_PRIVATE_CASE_TYPE ( \
597+ SCALARTYPE1, \
598+ decltype (c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
599+ __VA_ARGS__) \
600+ AT_PRIVATE_CASE_TYPE ( \
601+ SCALARTYPE2, \
602+ decltype (c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
603+ __VA_ARGS__) \
604+ AT_PRIVATE_CASE_TYPE ( \
605+ SCALARTYPE3, \
606+ decltype (c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), \
607+ __VA_ARGS__) \
608+ default : \
609+ AT_ERROR (#NAME, " not implemented for '" , TYPE, " '" ); \
610+ } \
611+ }()
612+
613+ #define AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3 ( \
614+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
583615 [&] { \
584616 switch (TYPE) { \
585617 AT_PRIVATE_CASE_TYPE (at::ScalarType::Byte, uint8_t , __VA_ARGS__) \
0 commit comments