
Commit a81edf9

desertfire authored and pytorchmergebot committed
[inductor] Fix cpp_wrapper codegen for ir.ComplexView (#116481)
Pull Request resolved: #116481
Approved by: https://github.com/htyu
1 parent b0629cd commit a81edf9

4 files changed: 29 additions & 15 deletions

test/inductor/test_cpu_cpp_wrapper.py

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ class BaseTest(NamedTuple):
         code_string_count: dict = {}

     for item in [
+        BaseTest("test_add_complex2"),
         BaseTest("test_as_strided"),  # buffer reuse
         BaseTest("test_bernoulli1"),
         BaseTest("test_bitwise"),  # int32

test/inductor/test_cuda_cpp_wrapper.py

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ class BaseTest(NamedTuple):

     # Maintain two separate test lists for cuda and cpp for now
     for item in [
+        BaseTest("test_add_complex2"),
         BaseTest("test_as_strided"),  # buffer reuse
         BaseTest("test_batch_norm_2d_2"),
         BaseTest("test_bernoulli1"),

test/inductor/test_torchinductor.py

Lines changed: 3 additions & 1 deletion
@@ -670,7 +670,9 @@ def fn(a, b):
         y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])

         _, code = run_and_get_code(fn, x, y)
-        self.assertEqual(code[0].count("aten.view"), 3)
+        self.assertEqual(
+            code[0].count("::view_dtype" if config.cpp_wrapper else "aten.view"), 3
+        )

     def test_add_complex3(self):
         # fix https://github.com/pytorch/pytorch/issues/115071
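
Note: a minimal standalone version of the pattern this test exercises might look like the sketch below; the exact decorator and the surrounding test harness are assumptions, and run_and_get_code is taken from torch._inductor.utils as in the test file.

import torch
from torch._inductor.utils import run_and_get_code

# Hypothetical standalone repro of the test_add_complex2 pattern: complex
# tensors are lowered through dtype views, and the generated wrapper code is
# inspected for the spelling of those views.
@torch.compile
def fn(a, b):
    return a + b

x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])

_, code = run_and_get_code(fn, x, y)
# The default Python wrapper spells the views as "aten.view...", while the
# C++ wrapper emits "...::view_dtype" calls, which is what the updated
# assertion accounts for.
print(code[0].count("aten.view"), code[0].count("::view_dtype"))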

torch/_inductor/ir.py

Lines changed: 24 additions & 14 deletions
@@ -4399,6 +4399,21 @@ class ExternKernelNode:
 }


+def get_aten_cpp_kernel_name(kernel):
+    # Calling with the default kernel name can lead to ambiguous behavior like the following example.
+    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
+    # repeat_interleave(const at::Tensor & self, int64_t repeats,
+    #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
+    assert (
+        isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten"
+    ), "Invalid aten kernel"
+    return (
+        f"at::{kernel.__name__.split('.')[0]}"
+        if kernel._overloadname == "default"
+        else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
+    )
+
+
 class FallbackKernel(ExternKernelAlloc):
     args_default_value: List[Dict[str, Any]]

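Note: as a quick illustration of what the new helper produces, the sketch below replicates its mapping outside of inductor; the two example overloads are assumed to exist in the installed PyTorch build.

import torch

def get_aten_cpp_kernel_name(kernel):
    # Same mapping as the helper added above: default overloads use the plain
    # at:: name, other overloads use the fully qualified at::_ops::...::call
    # spelling to avoid ambiguous C++ overload resolution.
    assert (
        isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten"
    ), "Invalid aten kernel"
    return (
        f"at::{kernel.__name__.split('.')[0]}"
        if kernel._overloadname == "default"
        else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
    )

print(get_aten_cpp_kernel_name(torch.ops.aten.view_as_real.default))
# at::view_as_real
print(get_aten_cpp_kernel_name(torch.ops.aten.repeat_interleave.self_int))
# at::_ops::repeat_interleave_self_int::call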
@@ -4655,8 +4670,6 @@ def codegen(self, wrapper):
         if kernel.namespace == "aten":
             # Aten Fallback Ops
             assert isinstance(kernel, torch._ops.OpOverload)
-            op_base_name = kernel.__name__.split(".")[0]
-
             if V.graph.cpp_wrapper:
                 if config.is_fbcode() and kernel not in has_c_shim:
                     log.warning(
@@ -4666,17 +4679,8 @@ def codegen(self, wrapper):
                     self.use_runtime_dispatch = True
                     self.set_cpp_kernel(kernel)
                 else:
-                    # Calling with the default kernel name can lead to ambiguous behavior like the following example.
-                    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
-                    # repeat_interleave(const at::Tensor & self, int64_t repeats,
-                    #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
-                    self.cpp_kernel_name = (
-                        f"at::{op_base_name}"
-                        if kernel._overloadname == "default"
-                        else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
-                    )
+                    self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel)
                     schema = kernel._schema
-
                     self.args_default_value = [
                         {"type": x.real_type, "value": x.default_value}
                         for x in schema.arguments
46914695
if x.kwarg_only
46924696
}
46934697
else:
4694-
self.python_kernel_name = f"aten.{op_base_name}"
4698+
self.python_kernel_name = str(kernel)
46954699

46964700
elif isinstance(kernel, torch._ops.HigherOrderOperator):
46974701
if getattr(torch._prims.rng_prims, kernel.__name__, None) is kernel:
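
Note: the effect of switching to str(kernel) can be seen directly in a small sketch; the printed strings are what OpOverload stringification is expected to yield, not verified against every build.

import torch

kernel = torch.ops.aten.view.dtype
# Old spelling: only the base op name, dropping the overload suffix.
print(f"aten.{kernel.__name__.split('.')[0]}")  # aten.view
# New spelling: str() keeps the overload suffix.
print(str(kernel))                              # aten.view.dtype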
@@ -4825,6 +4829,7 @@ def __init__(
         self,
         layout,
         python_kernel_name,
+        cpp_kernel_name,
         tensor_args,
         nontensor_args,
     ):
@@ -4838,6 +4843,7 @@ def __init__(
         # output through the abi-compatible interface.
         self.outputs: Sequence[Any] = []
         self.python_kernel_name = python_kernel_name
+        self.cpp_kernel_name = cpp_kernel_name

     @classmethod
     def create(cls, kernel, *args, **kwargs):
@@ -4854,7 +4860,11 @@ def create(cls, kernel, *args, **kwargs):
         assert device, "Not sure where to find device info"

         packed = ComplexView(
-            MultiOutputLayout(device), str(kernel), tensor_args, non_tensor_args
+            MultiOutputLayout(device),
+            str(kernel),
+            get_aten_cpp_kernel_name(kernel),
+            tensor_args,
+            non_tensor_args,
         )

         layout = FixedLayout(
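
Note: putting the pieces together, an end-to-end check of the fix could look like the sketch below; it assumes torch._inductor.config.cpp_wrapper is the flag that enables the C++ wrapper codegen path exercised by this commit.

import torch
import torch._inductor.config as inductor_config

# With cpp_wrapper enabled, the ComplexView fallback now carries a C++ kernel
# name, so compiling an op over complex tensors should codegen cleanly.
inductor_config.cpp_wrapper = True  # assumption: enables the C++ wrapper path

@torch.compile
def fn(a, b):
    return a + b

x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j])
y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j])
print(fn(x, y))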
