Skip to content

Commit 7433744

Browse files
committed
Fix test on windows
[ghstack-poisoned]
1 parent 1c77b13 commit 7433744

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

test/test_transformers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2817,7 +2817,9 @@ def test_fused_sdp_choice(self, device, type: str):
28172817
elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
28182818
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
28192819
elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows
2820-
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
2820+
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
2821+
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
2822+
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
28212823
else:
28222824
self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
28232825

0 commit comments

Comments
 (0)