178 changes: 53 additions & 125 deletions test/inductor/test_custom_op_autotune.py
@@ -216,115 +216,6 @@ def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}"
)

@skipIfXpu
def test_mlp_custom_op_autotune(self):
"""Test MLP autotuning with method parameter controlling different decomposition variants.

Validates parametric tuning where the same decomposition function uses different
algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights).
"""
test_op_name = f"test_lib::mlp_{id(self)}"

def mlp_variants(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
"""MLP implementation with different computational approaches controlled by method parameter."""

if method == 0:
gate_proj = torch.matmul(input_tensor, gate_weight)
up_proj = torch.matmul(input_tensor, up_weight)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)

elif method == 1:
batch_shape = input_tensor.shape[:-1]
hidden_dim = input_tensor.shape[-1]
output_dim = down_weight.shape[-1]

input_2d = input_tensor.view(-1, hidden_dim)

gate_proj = torch.mm(input_2d, gate_weight)
up_proj = torch.mm(input_2d, up_weight)

gated = torch.relu(gate_proj) * up_proj
output_2d = torch.mm(gated, down_weight)

return output_2d.view(*batch_shape, output_dim)

@torch.library.custom_op(test_op_name, mutates_args=())
def test_mlp_op(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
return mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=method
)

@test_mlp_op.register_fake
def _(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
):
return torch.empty(
input_tensor.shape[:-1] + (down_weight.shape[-1],),
device=input_tensor.device,
dtype=input_tensor.dtype,
)

# Use explicit config with method parameter as tuning knob
register_custom_op_autotuning(
test_mlp_op,
configs=[
CustomOpConfig(method=0),
CustomOpConfig(method=1),
],
name="test_mlp_autotuned",
input_gen_fns={
"input_tensor": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"gate_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"up_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"down_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
},
)

# Create test inputs
input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()

# Test that all method variants produce numerically equivalent results
expected = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)

# Test autotuning
self._run_autotune_test(
test_mlp_op,
(input_tensor, gate_weight, up_weight, down_weight),
expected,
"MLP",
)

def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
"""Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
# Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
@@ -335,12 +226,12 @@ def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):

@skipIfXpu
def test_decompose_k_custom_op_autotune(self):
"""Test decompose_k autotuning with parametric tuning for k_splits values.
"""Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale).

Validates numerical parameter sweep where k_splits controls how the K dimension
is decomposed for matrix multiplication (k_splits in [32, 64, 128, 256]).
Validates that the custom op encapsulates the entire fused operation with parametric
tuning for k_splits values controlling how the K dimension is decomposed.
"""
test_op_name = f"test_lib::decompose_k_{id(self)}"
test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}"

def decompose_k_implementation(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
@@ -363,19 +254,23 @@ def decompose_k_implementation(
return torch.sum(result, dim=0) # [m, n]

@torch.library.custom_op(test_op_name, mutates_args=())
def test_decompose_k_op(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
def matmul_relu_epilogue_op(
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - custom op using the decomposition."""
return decompose_k_implementation(a, b, k_splits)

@test_decompose_k_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
"""Matmul with decompose_k + bias + relu + scale (complete epilogue fusion)."""
matmul_result = decompose_k_implementation(a, b, k_splits)
biased = matmul_result + bias
activated = torch.relu(biased)
scaled = activated * 2.0
return scaled

@matmul_relu_epilogue_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4):
return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)

# Register autotuning with different k_splits values using decomposition function
# Register autotuning with different k_splits values
register_custom_op_autotuning(
test_decompose_k_op,
matmul_relu_epilogue_op,
configs=[
CustomOpConfig(k_splits=2),
CustomOpConfig(k_splits=4),
@@ -385,7 +280,7 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
CustomOpConfig(k_splits=64),
CustomOpConfig(k_splits=128),
],
name="test_decompose_k_autotuned",
name="matmul_relu_epilogue_autotuned",
input_gen_fns={
"a": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
Expand All @@ -395,12 +290,45 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
fake_tensor, device=self.device
)
* 0.1,
"bias": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
},
)

# Create test inputs
a, b = self._create_decompose_k_inputs()
expected = a @ b
self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK")
bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1

# Compile the model using the custom op
@torch.compile
def test_model(a, b, bias):
return matmul_relu_epilogue_op(a, b, bias)

torch._dynamo.reset()

with config.patch(
max_autotune=True,
benchmark_fusion=True,
):
compiled_result = test_model(a, b, bias)

def reference_model(a, b, bias):
matmul_result = a @ b
biased = matmul_result + bias
activated = torch.relu(biased)
scaled = activated * 2.0
return scaled

expected = reference_model(a, b, bias)

torch.testing.assert_close(
compiled_result,
expected,
rtol=2e-1,
atol=5e-1,
)

@skipIfXpu
def test_multi_parameter_tuning(self):
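The body of decompose_k_implementation is collapsed in the hunk above. As background, here is a minimal sketch of the k-way decomposition it performs, reconstructed from the visible signature and the final torch.sum(result, dim=0); the reshape and bmm details are assumptions, not the PR's exact code.

import torch

def decompose_k_sketch(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    """Split the K dimension of an [m, k] @ [k, n] matmul into k_splits chunks."""
    m, k = a.shape
    _, n = b.shape
    assert k % k_splits == 0, "K must be divisible by k_splits"
    # [m, k] -> [k_splits, m, k // k_splits]: one K-slice per batch entry
    a_chunks = a.view(m, k_splits, k // k_splits).permute(1, 0, 2)
    # [k, n] -> [k_splits, k // k_splits, n]
    b_chunks = b.view(k_splits, k // k_splits, n)
    partials = torch.bmm(a_chunks, b_chunks)  # [k_splits, m, n] partial products
    return torch.sum(partials, dim=0)  # [m, n]

Each CustomOpConfig(k_splits=...) pins a different split factor; autotuning benchmarks these numerically equivalent variants and keeps the fastest, and the bias + relu + scale epilogue added by this PR can then fuse into the winner.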
25 changes: 24 additions & 1 deletion torch/_inductor/codegen/subgraph.py
@@ -24,6 +24,22 @@
log = logging.getLogger(__name__)


def inline_subgraph_to_ir_nodes(
gm: torch.fx.GraphModule, inputs: list[Any], name: str
) -> Any:
"""Inline a subgraph by converting its FX operations to individual IR nodes.

This converts a subgraph to multiple ComputedBuffer nodes (fusable),
enabling epilogue fusion with subsequent operations.

Returns:
TensorBox containing the final operation result as individual IR nodes
"""
from torch._inductor.lowering import process_subgraph_nodes

return process_subgraph_nodes(gm, inputs)


class SubgraphChoiceCaller(ir.ChoiceCaller):
"""
Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
@@ -261,7 +277,14 @@ def make_fx_graph(
# decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs
from torch.fx.experimental.proxy_tensor import make_fx

return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)
from ..decomposition import select_decomp_table

decomposition_table = select_decomp_table()

return make_fx(
functools.partial(decomp, **decomp_kwargs),
decomposition_table=decomposition_table,
)(*args)

# Generate descriptive name for this variant
variant_name = self._generate_variant_name(decomp, decomp_kwargs)
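The make_fx_graph change above now traces each decomposition variant through Inductor's decomposition table, so composite ops are lowered to primitives before the variants are benchmarked. A self-contained sketch of that tracing step, assuming only the make_fx and select_decomp_table calls shown in the diff (the traced function is invented for illustration):

import functools
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.decomposition import select_decomp_table

def toy_decomp(a: torch.Tensor, b: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
    # softmax is a composite op; the decomposition table lowers it to
    # primitive ops in the traced graph.
    return torch.softmax(a @ b, dim=-1) * scale

a, b = torch.randn(4, 8), torch.randn(8, 4)
# Bind the non-tensor config params first, then trace; this mirrors
# make_fx(functools.partial(decomp, **decomp_kwargs), ...)(*args) in the diff.
gm = make_fx(
    functools.partial(toy_decomp, scale=2.0),
    decomposition_table=select_decomp_table(),
)(a, b)
print(gm.graph)  # decomposed aten/prims ops rather than a single softmax call

The decomposed graph is what inline_subgraph_to_ir_nodes later turns into individual, fusable ComputedBuffer nodes.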
43 changes: 31 additions & 12 deletions torch/_inductor/kernel/custom_op.py
@@ -5,6 +5,7 @@
from typing import Any, Callable, Optional, Union

import torch
from torch._inductor import config
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
@@ -157,7 +158,6 @@ def _adapt_user_input_gen_fns(

Uses V.graph.sizevars.size_hints() to pick best-guess concrete values for dynamic shapes.
"""
from torch._inductor import config

name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based_fns = {}
@@ -237,6 +237,7 @@ def autotune_custom_op(

This function generates multiple implementation choices for a custom operation and
uses Inductor's autotuning system to select the best performing variant at runtime.
After selecting the best choice, it applies inline fusion if the winning choice has a graph.

Args:
name: Unique identifier for the autotuning operation
@@ -319,14 +320,34 @@
)
input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns)

return autotune_select_algorithm(
# Run autotuning and get both result and winning choice
selected_result, winning_choice = autotune_select_algorithm(
name=name,
choices=choices,
input_nodes=list(inputs),
layout=choices[0].layout,
input_gen_fns=input_gen_fns,
return_choice=True,
)

# Apply inlining for fusion if the winning choice has a graph; otherwise return the result as-is (default fallback impl)
if winning_choice.gm is not None:
log.debug(
"Inlining winning choice: %s (name=%s)",
getattr(winning_choice, "name", type(winning_choice).__name__),
name,
)
from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes

return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name)

log.debug(
"Winning choice does not support inlining: %s (name=%s)",
getattr(winning_choice, "name", type(winning_choice).__name__),
name,
)
return selected_result


def register_custom_op_autotuning(
custom_op: torch._library.custom_ops.CustomOpDef,
@@ -359,7 +380,7 @@ def my_attention(query, key, value, head_dim=32):
"query": lambda fake: torch.randn_like(fake, device='cuda'),
"key": lambda fake: torch.randn_like(fake, device='cuda'),
"value": lambda fake: torch.randn_like(fake, device='cuda'),
}
},
)
"""
from torch._library.custom_ops import CustomOpDef
@@ -377,12 +398,12 @@ def my_attention(query, key, value, head_dim=32):
raise TypeError(f"configs must be a list or tuple, got {type(configs)}")

processed_configs = []
for config in configs:
if isinstance(config, CustomOpConfig):
processed_configs.append(config)
for cfg in configs:
if isinstance(cfg, CustomOpConfig):
processed_configs.append(cfg)
else:
raise TypeError(
f"Each config must be a CustomOpConfig object, got {type(config)}"
f"Each config must be a CustomOpConfig object, got {type(cfg)}"
)

if not processed_configs:
Expand All @@ -401,14 +422,12 @@ def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
decompositions = []
non_tensor_args = []

for config in processed_configs:
decomp = config.get_decomposition(default_impl=default_impl)
for cfg in processed_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)

# Merge config params with runtime kwargs (runtime takes precedence)
merged_kwargs = _merge_config_and_runtime_kwargs(
config.params, runtime_kwargs
)
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
non_tensor_args.append(merged_kwargs)

result = autotune_custom_op(
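Taken together, the registration flow this file exposes looks roughly like the sketch below, modeled on the my_attention docstring example and the tests above; the op, shapes, and import paths are assumptions for illustration, not a documented public API.

import torch
from torch._inductor.kernel.custom_op import (
    CustomOpConfig,
    register_custom_op_autotuning,
)

@torch.library.custom_op("mylib::my_matmul", mutates_args=())
def my_matmul(a: torch.Tensor, b: torch.Tensor, method: int = 0) -> torch.Tensor:
    # Numerically equivalent variants; 'method' is the tuning knob,
    # mirroring the method parameter in the removed MLP test.
    if method == 0:
        return a @ b
    return torch.mm(a, b)

@my_matmul.register_fake
def _(a: torch.Tensor, b: torch.Tensor, method: int = 0) -> torch.Tensor:
    return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)

register_custom_op_autotuning(
    my_matmul,
    configs=[CustomOpConfig(method=0), CustomOpConfig(method=1)],
    name="my_matmul_autotuned",
    input_gen_fns={
        # As in the docstring example: realize fake tensors on the target device.
        "a": lambda fake: torch.randn_like(fake, device="cuda"),
        "b": lambda fake: torch.randn_like(fake, device="cuda"),
    },
)

Under torch.compile with max_autotune=True, calls to my_matmul go through autotuning_lowering: each config's decomposition is traced, benchmarked, and, after this PR, the winning choice's graph is inlined as fusable IR nodes instead of remaining an opaque call.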