@@ -4399,6 +4399,21 @@ class ExternKernelNode:
43994399}
44004400
44014401
def get_aten_cpp_kernel_name(kernel):
    """Return the C++ kernel name to emit for an aten ``OpOverload``.

    Calling with the default kernel name can lead to ambiguous behavior like
    the following example:
        repeat_interleave(const at::Tensor & repeats,
            c10::optional<int64_t> output_size=c10::nullopt)
        repeat_interleave(const at::Tensor & self, int64_t repeats,
            c10::optional<int64_t> dim=c10::nullopt,
            c10::optional<int64_t> output_size=c10::nullopt)
    so non-default overloads are routed through the unambiguous
    ``at::_ops::<op>_<overload>::call`` entry point instead of ``at::<op>``.

    Args:
        kernel: a ``torch._ops.OpOverload`` in the ``aten`` namespace.

    Returns:
        str: the fully qualified C++ callable name for cpp_wrapper codegen.
    """
    assert (
        isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten"
    ), "Invalid aten kernel"
    # OpOverload.__name__ is "<op>.<overload>", e.g. "add.Tensor" / "abs.default".
    # NOTE: the f-strings must contain no stray spaces — a trailing space or a
    # space before "::call" would produce uncompilable C++ identifiers.
    return (
        f"at::{kernel.__name__.split('.')[0]}"
        if kernel._overloadname == "default"
        else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
    )
4415+
4416+
44024417class FallbackKernel (ExternKernelAlloc ):
44034418 args_default_value : List [Dict [str , Any ]]
44044419
@@ -4655,8 +4670,6 @@ def codegen(self, wrapper):
46554670 if kernel .namespace == "aten" :
46564671 # Aten Fallback Ops
46574672 assert isinstance (kernel , torch ._ops .OpOverload )
4658- op_base_name = kernel .__name__ .split ("." )[0 ]
4659-
46604673 if V .graph .cpp_wrapper :
46614674 if config .is_fbcode () and kernel not in has_c_shim :
46624675 log .warning (
@@ -4666,17 +4679,8 @@ def codegen(self, wrapper):
46664679 self .use_runtime_dispatch = True
46674680 self .set_cpp_kernel (kernel )
46684681 else :
4669- # Calling with the default kernel name can lead to ambiguous behavior like the following example.
4670- # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
4671- # repeat_interleave(const at::Tensor & self, int64_t repeats,
4672- # c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
4673- self .cpp_kernel_name = (
4674- f"at::{ op_base_name } "
4675- if kernel ._overloadname == "default"
4676- else f"at::_ops::{ kernel .__name__ .replace ('.' , '_' )} ::call"
4677- )
4682+ self .cpp_kernel_name = get_aten_cpp_kernel_name (kernel )
46784683 schema = kernel ._schema
4679-
46804684 self .args_default_value = [
46814685 {"type" : x .real_type , "value" : x .default_value }
46824686 for x in schema .arguments
@@ -4691,7 +4695,7 @@ def codegen(self, wrapper):
46914695 if x .kwarg_only
46924696 }
46934697 else :
4694- self .python_kernel_name = f"aten. { op_base_name } "
4698+ self .python_kernel_name = str ( kernel )
46954699
46964700 elif isinstance (kernel , torch ._ops .HigherOrderOperator ):
46974701 if getattr (torch ._prims .rng_prims , kernel .__name__ , None ) is kernel :
@@ -4825,6 +4829,7 @@ def __init__(
48254829 self ,
48264830 layout ,
48274831 python_kernel_name ,
4832+ cpp_kernel_name ,
48284833 tensor_args ,
48294834 nontensor_args ,
48304835 ):
@@ -4838,6 +4843,7 @@ def __init__(
48384843 # output through the abi-compatible interface.
48394844 self .outputs : Sequence [Any ] = []
48404845 self .python_kernel_name = python_kernel_name
4846+ self .cpp_kernel_name = cpp_kernel_name
48414847
48424848 @classmethod
48434849 def create (cls , kernel , * args , ** kwargs ):
@@ -4854,7 +4860,11 @@ def create(cls, kernel, *args, **kwargs):
48544860 assert device , "Not sure where to find device info"
48554861
48564862 packed = ComplexView (
4857- MultiOutputLayout (device ), str (kernel ), tensor_args , non_tensor_args
4863+ MultiOutputLayout (device ),
4864+ str (kernel ),
4865+ get_aten_cpp_kernel_name (kernel ),
4866+ tensor_args ,
4867+ non_tensor_args ,
48584868 )
48594869
48604870 layout = FixedLayout (
0 commit comments