
Commit f17a038

Update
[ghstack-poisoned]
1 parent 2ef7054 commit f17a038

4 files changed

Lines changed: 20 additions & 22 deletions


test/export/test_export.py

Lines changed: 1 addition & 1 deletion
@@ -968,7 +968,7 @@ def forward(self, x):
 view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]); linear_3 = None
 sdpa_score0 = self.sdpa_score0
 sdpa_mask0 = self.sdpa_mask0
-flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
+flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
 getitem = flex_attention[0]
 getitem_1 = flex_attention[1]; getitem_1 = None
 getitem_2 = flex_attention[2]; flex_attention = getitem_2 = None
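
The hunk above shows that the exported graph now records the resolved default, 'BACKEND': 'AUTO', in the kernel-options dict of the flex_attention higher-order op even when the caller passes no kernel_options. A minimal sketch of how that default would surface under torch.export; the module, shapes, and dtype here are illustrative, not taken from the test:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # No kernel_options passed; the default BACKEND='AUTO' is stamped into the graph.
        return flex_attention(q, k, v)


q, k, v = (
    torch.randn(2, 1, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
ep = torch.export.export(TinyAttention(), (q, k, v))
# The printed graph should contain a torch.ops.higher_order.flex_attention call whose
# kernel-options dict starts with 'BACKEND': 'AUTO', as the test above asserts.
print(ep.graph_module.code)
```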

test/inductor/test_flex_attention.py

Lines changed: 10 additions & 10 deletions
@@ -3525,7 +3525,7 @@ def test_kernel_options_argument_is_respected(self, device):

     @supported_platform
     @skip_on_cpu
-    def test_force_impl_default_matches_triton_large(self, device):
+    def test_backend_auto_matches_triton_large(self, device):
         """BACKEND='AUTO' should follow Triton heuristics on large shapes."""
         make_tensor = functools.partial(
             torch.randn,
@@ -3558,7 +3558,7 @@ def compile_and_run(kernel_options):

     @supported_platform
     @skip_on_cpu
-    def test_force_impl_decode_matches_default(self, device):
+    def test_backend_triton_decode_matches_auto(self, device):
         """BACKEND='TRITON_DECODE' should match heuristics on decode-friendly shapes."""
         make_tensor = functools.partial(
             torch.randn,
@@ -3607,7 +3607,7 @@ def compile_and_run(kernel_options):

     @supported_platform
     @skip_on_cpu
-    def test_force_impl_decode_errors_when_not_supported(self, device):
+    def test_backend_triton_decode_errors_when_not_supported(self, device):
         """Requesting decode on unsupported shapes should raise a helpful error."""
         make_tensor = functools.partial(
             torch.randn,
@@ -3627,7 +3627,7 @@ def test_force_impl_decode_errors_when_not_supported(self, device):

     @supported_platform
     @skip_on_cpu
-    def test_force_impl_decode_errors_with_non_power_of_two_gqa(self, device):
+    def test_backend_triton_decode_errors_with_non_power_of_two_gqa(self, device):
         """BACKEND='TRITON_DECODE' should fail when GQA ratio is not a power of two."""
         q = torch.randn(
             1, 3, 64, 64, device=device, dtype=torch.float16, requires_grad=False
@@ -3654,7 +3654,7 @@ def test_force_impl_decode_errors_with_non_power_of_two_gqa(self, device):

     @supported_platform
     @skip_on_cpu
-    def test_force_impl_rejects_legacy_force_use_flag(self, device):
+    def test_backend_rejects_legacy_force_use_flag(self, device):
         """Combining BACKEND with FORCE_USE_FLEX_ATTENTION should raise an error."""
         make_tensor = functools.partial(
             torch.randn,
@@ -3681,7 +3681,7 @@ def test_force_impl_rejects_legacy_force_use_flag(self, device):
         )

     @supported_platform
-    def test_force_impl_defaults_and_rejects_invalid(self, device):
+    def test_backend_defaults_and_rejects_invalid(self, device):
         device = torch.device(device)
         query = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
         key = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
@@ -4333,7 +4333,7 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_

 score_mod_0 = self.score_mod_0
 mask_fn_0 = self.mask_fn_0
-flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
+flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
 out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
 return (out,)

@@ -4369,11 +4369,11 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
 """\
 class GraphModule(torch.nn.Module):
 def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
-full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
 fw_graph0 = self.fw_graph0
 joint_graph0 = self.joint_graph0
 mask_graph0 = self.mask_graph0
-flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
+flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
 getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
 getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
 getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
@@ -4393,7 +4393,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3

 class mask_graph0(torch.nn.Module):
 def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
 return full_default
 """.replace( # noqa: B950
 "GPU_TYPE", torch.device(device).type

torch/_inductor/kernel/flex/flex_attention.py

Lines changed: 6 additions & 8 deletions
@@ -59,8 +59,8 @@ def _sanitize_kernel_options_for_triton(
     to avoid passing to triton constexpr dict
     """
     sanitized = dict(kernel_options)
-    force_impl = cast(_Backend, sanitized.pop("BACKEND", "AUTO"))
-    return sanitized, force_impl
+    backend = cast(_Backend, sanitized.pop("BACKEND", "AUTO"))
+    return sanitized, backend


 @SymbolicGridFn
@@ -182,7 +182,7 @@ def flex_attention(
     )
     freeze_irnodes(mask_graph_buffer)

-    kernel_options, force_impl = _sanitize_kernel_options_for_triton(kernel_options)
+    kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options)
     # Mark symbols in custom kernel options as static shapes and add guards.
     kernel_options = {
         k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v
@@ -196,11 +196,9 @@ def flex_attention(
     can_use_decode = _use_flex_decoding(
         query, kv_indices, value, kernel_options, enable_gqa
     )
-    use_decode = (force_impl == "TRITON_DECODE") or (
-        force_impl == "AUTO" and can_use_decode
-    )
+    use_decode = (backend == "TRITON_DECODE") or (backend == "AUTO" and can_use_decode)

-    if force_impl == "TRITON_DECODE" and not can_use_decode:
+    if backend == "TRITON_DECODE" and not can_use_decode:
         raise RuntimeError(
             "BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used for this input. "
             "flex_decoding is only available for short sequence lengths with specific configurations."
@@ -253,7 +251,7 @@ def flex_attention(
         mask_graph,
         kernel_options,
         num_score_mod_placeholders=len(placeholder_inps),
-        force_impl=force_impl,
+        backend=backend,
     ):
         return create_flex_flash_attention_kernel(
             query,
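
Condensing the hunks above: BACKEND is popped before the remaining options reach the Triton constexpr dict, and the decode kernel runs either when it is forced or when AUTO defers to the heuristic. A standalone sketch of that selection path; choose_decode is a hypothetical helper that condenses logic inlined in flex_attention, and can_use_decode stands in for the _use_flex_decoding(...) result:

```python
from typing import Any, Literal

_Backend = Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"]


def _sanitize_kernel_options_for_triton(
    kernel_options: dict[str, Any],
) -> tuple[dict[str, Any], _Backend]:
    # Pop BACKEND so it is never forwarded to the Triton constexpr dict.
    sanitized = dict(kernel_options)
    backend: _Backend = sanitized.pop("BACKEND", "AUTO")
    return sanitized, backend


def choose_decode(backend: _Backend, can_use_decode: bool) -> bool:
    # TRITON_DECODE forces the decode kernel; AUTO defers to the heuristic result.
    if backend == "TRITON_DECODE" and not can_use_decode:
        raise RuntimeError(
            "BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used "
            "for this input."
        )
    return backend == "TRITON_DECODE" or (backend == "AUTO" and can_use_decode)


opts, backend = _sanitize_kernel_options_for_triton({"BACKEND": "TRITON_DECODE", "BLOCK_M": 64})
assert backend == "TRITON_DECODE" and "BACKEND" not in opts
```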

torch/_inductor/kernel/flex/flex_flash_attention.py

Lines changed: 3 additions & 3 deletions
@@ -171,7 +171,7 @@ def _use_flex_flash_attention(
     mask_graph: Subgraph,
     kernel_options: dict[str, Any],
     num_score_mod_placeholders: int,
-    force_impl: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"],
+    backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"],
 ) -> bool:
     """Determine if we should use flex flash attention for the given inputs.

@@ -180,13 +180,13 @@ def _use_flex_flash_attention(
         mask_graph: The mask modification subgraph
         kernel_options: Kernel configuration options
         num_score_mod_placeholders: Number of placeholders in score_mod
-        force_impl: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE)
+        backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE)

     Returns:
         True if flash attention should be used, False otherwise
     """
     # Flash is experimental and must be explicitly requested
-    if force_impl != "FLASH":
+    if backend != "FLASH":
         return False

     can_use, reason = _can_use_flex_flash_attention(
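
Because the flash path is experimental, it only runs when requested explicitly; with AUTO the gate above returns False before any capability check. A hedged request sketch, with illustrative shapes; whether the request actually lowers to the flash kernel still depends on _can_use_flex_flash_attention, which this hunk does not show:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

q, k, v = (
    torch.randn(2, 8, 2048, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3)
)

# Explicit opt-in: anything other than BACKEND='FLASH' skips the flash kernel entirely.
out = torch.compile(flex_attention)(q, k, v, kernel_options={"BACKEND": "FLASH"})
```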
