@@ -3525,7 +3525,7 @@ def test_kernel_options_argument_is_respected(self, device):
 
     @supported_platform
     @skip_on_cpu
-    def test_force_impl_default_matches_triton_large(self, device):
+    def test_backend_auto_matches_triton_large(self, device):
         """BACKEND='AUTO' should follow Triton heuristics on large shapes."""
         make_tensor = functools.partial(
             torch.randn,
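
For orientation, a minimal sketch of the behavior this renamed test covers, assuming the BACKEND key in kernel_options that this diff introduces and an explicit 'TRITON' value (the shapes and tolerances here are illustrative, not the test's exact values):

import functools
import torch
from torch.nn.attention.flex_attention import flex_attention

make_tensor = functools.partial(
    torch.randn, 2, 8, 2048, 64, device="cuda", dtype=torch.float16
)
q, k, v = make_tensor(), make_tensor(), make_tensor()

compiled = torch.compile(flex_attention)
# On a large shape the heuristics should already pick the Triton kernel,
# so 'AUTO' and an explicit 'TRITON' request should agree numerically.
out_auto = compiled(q, k, v, kernel_options={"BACKEND": "AUTO"})
out_triton = compiled(q, k, v, kernel_options={"BACKEND": "TRITON"})
torch.testing.assert_close(out_auto, out_triton, atol=2e-2, rtol=2e-2)
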
@@ -3558,7 +3558,7 @@ def compile_and_run(kernel_options):
 
     @supported_platform
     @skip_on_cpu
-    def test_force_impl_decode_matches_default(self, device):
+    def test_backend_triton_decode_matches_auto(self, device):
         """BACKEND='TRITON_DECODE' should match heuristics on decode-friendly shapes."""
         make_tensor = functools.partial(
             torch.randn,
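
Similarly, a sketch of the decode comparison, assuming a single-query-token shape counts as decode-friendly (values illustrative):

import torch
from torch.nn.attention.flex_attention import flex_attention

# One query token against a long KV cache: the classic decode shape.
q = torch.randn(2, 8, 1, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)

compiled = torch.compile(flex_attention)
out_decode = compiled(q, k, v, kernel_options={"BACKEND": "TRITON_DECODE"})
out_auto = compiled(q, k, v, kernel_options={"BACKEND": "AUTO"})
torch.testing.assert_close(out_decode, out_auto, atol=2e-2, rtol=2e-2)
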
@@ -3607,7 +3607,7 @@ def compile_and_run(kernel_options):
 
     @supported_platform
     @skip_on_cpu
-    def test_force_impl_decode_errors_when_not_supported(self, device):
+    def test_backend_triton_decode_errors_when_not_supported(self, device):
         """Requesting decode on unsupported shapes should raise a helpful error."""
         make_tensor = functools.partial(
             torch.randn,
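
And the failure mode: forcing the decode kernel onto a non-decode shape should raise rather than silently fall back. A hedged sketch (the exact exception type and message are assumptions):

import torch
from torch.nn.attention.flex_attention import flex_attention

# A long query sequence is not a decode workload.
q = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

compiled = torch.compile(flex_attention)
try:
    compiled(q, k, v, kernel_options={"BACKEND": "TRITON_DECODE"})
except Exception as exc:  # exact exception type is an assumption
    print(f"decode backend rejected the shape: {exc}")
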
@@ -3627,7 +3627,7 @@ def test_force_impl_decode_errors_when_not_supported(self, device):
 
     @supported_platform
     @skip_on_cpu
-    def test_force_impl_decode_errors_with_non_power_of_two_gqa(self, device):
+    def test_backend_triton_decode_errors_with_non_power_of_two_gqa(self, device):
         """BACKEND='TRITON_DECODE' should fail when GQA ratio is not a power of two."""
         q = torch.randn(
             1, 3, 64, 64, device=device, dtype=torch.float16, requires_grad=False
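
The q shape (1, 3, 64, 64) pairs 3 query heads with what is presumably a single KV head, giving a GQA ratio of 3. A sketch of the rejection, assuming enable_gqa and a 1-KV-head layout:

import torch
from torch.nn.attention.flex_attention import flex_attention

# 3 query heads over 1 KV head -> GQA ratio 3, not a power of two.
q = torch.randn(1, 3, 64, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 1, 64, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 1, 64, 64, device="cuda", dtype=torch.float16)

compiled = torch.compile(flex_attention)
try:
    compiled(q, k, v, enable_gqa=True, kernel_options={"BACKEND": "TRITON_DECODE"})
except Exception as exc:  # exact exception type is an assumption
    print(f"non-power-of-two GQA ratio rejected: {exc}")
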
@@ -3654,7 +3654,7 @@ def test_force_impl_decode_errors_with_non_power_of_two_gqa(self, device):
 
     @supported_platform
     @skip_on_cpu
-    def test_force_impl_rejects_legacy_force_use_flag(self, device):
+    def test_backend_rejects_legacy_force_use_flag(self, device):
         """Combining BACKEND with FORCE_USE_FLEX_ATTENTION should raise an error."""
         make_tensor = functools.partial(
             torch.randn,
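
A sketch of the conflict check: the legacy FORCE_USE_FLEX_ATTENTION kernel option and the new BACKEND selector express overlapping intent, so combining them is rejected (exception type assumed):

import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

compiled = torch.compile(flex_attention)
try:
    compiled(
        q, k, v,
        kernel_options={"BACKEND": "AUTO", "FORCE_USE_FLEX_ATTENTION": True},
    )
except Exception as exc:  # exact exception type is an assumption
    print(f"conflicting options rejected: {exc}")
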
@@ -3681,7 +3681,7 @@ def test_force_impl_rejects_legacy_force_use_flag(self, device):
         )
 
     @supported_platform
-    def test_force_impl_defaults_and_rejects_invalid(self, device):
+    def test_backend_defaults_and_rejects_invalid(self, device):
         device = torch.device(device)
         query = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
         key = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32)
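
This last rename covers the default-and-validation path (note it has no @skip_on_cpu, so it also runs in eager on CPU). A sketch, assuming omitting BACKEND behaves like 'AUTO' and unknown strings are rejected up front:

import torch
from torch.nn.attention.flex_attention import flex_attention

query = torch.randn(1, 1, 4, 8, dtype=torch.float32)
key = torch.randn(1, 1, 4, 8, dtype=torch.float32)
value = torch.randn(1, 1, 4, 8, dtype=torch.float32)

out_default = flex_attention(query, key, value)  # implied BACKEND='AUTO'
try:
    flex_attention(query, key, value, kernel_options={"BACKEND": "NOT_A_BACKEND"})
except Exception as exc:  # exact exception type is an assumption
    print(f"invalid backend rejected: {exc}")
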
@@ -4333,7 +4333,7 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_
 
         score_mod_0 = self.score_mod_0
         mask_fn_0 = self.mask_fn_0
-        flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
+        flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
         out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
         return (out,)
 
@@ -4369,11 +4369,11 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
43694369 """\
43704370 class GraphModule(torch.nn.Module):
43714371 def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
4372- full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE ', index=0), pin_memory = False)
4372+ full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda ', index=0), pin_memory = False)
43734373 fw_graph0 = self.fw_graph0
43744374 joint_graph0 = self.joint_graph0
43754375 mask_graph0 = self.mask_graph0
4376- flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
4376+ flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'BACKEND': 'AUTO', ' PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
43774377 getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
43784378 getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
43794379 getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
@@ -4393,7 +4393,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
 
     class mask_graph0(torch.nn.Module):
         def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-            full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+            full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
             return full_default
 """.replace(  # noqa: B950
                 "GPU_TYPE", torch.device(device).type
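
The two expected-graph updates above show the user-visible consequence of the change: the flex_attention higher-order op's kernel-options dict now records 'BACKEND': 'AUTO' even when the caller passes no kernel_options. One way to observe this yourself is a print-only torch.compile backend (a sketch; shapes are illustrative):

import torch
from torch.nn.attention.flex_attention import flex_attention

def show_graph(gm, example_inputs):
    # Print the FX graph Dynamo captured; per this diff, the flex_attention
    # HOP call should carry 'BACKEND': 'AUTO' in its kernel-options dict.
    gm.print_readable()
    return gm.forward

q = torch.randn(2, 2, 128, 4, device="cuda", dtype=torch.float64)
k, v = torch.randn_like(q), torch.randn_like(q)
torch.compile(flex_attention, backend=show_graph)(q, k, v)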