@@ -1554,17 +1554,17 @@ def __init__(self) -> None:
15541554 (torch .float32 , 64 ): ROCmFlexConfig (128 , 32 , 1 , 4 ),
15551555 (torch .float32 , 128 ): ROCmFlexConfig (128 , 32 , 1 , 4 ),
15561556 (torch .float32 , 256 ): ROCmFlexConfig (64 , 16 , 1 , 4 ),
1557- (torch .bfloat16 , 64 ): ROCmFlexConfig (128 , 64 , 2 , 8 ),
1558- (torch .bfloat16 , 128 ): ROCmFlexConfig (128 , 64 , 2 , 8 ),
1559- (torch .bfloat16 , 256 ): ROCmFlexConfig (32 , 64 , 2 , 8 ),
1557+ (torch .bfloat16 , 64 ): ROCmFlexConfig (128 , 64 , 2 , 4 ),
1558+ (torch .bfloat16 , 128 ): ROCmFlexConfig (128 , 64 , 2 , 4 ),
1559+ (torch .bfloat16 , 256 ): ROCmFlexConfig (32 , 64 , 2 , 4 ),
15601560 (torch .float16 , 64 ): ROCmFlexConfig (128 , 64 , 2 , 8 ),
15611561 (torch .float16 , 128 ): ROCmFlexConfig (128 , 64 , 2 , 8 ),
15621562 (torch .float16 , 256 ): ROCmFlexConfig (32 , 64 , 2 , 4 ),
15631563 }
15641564
15651565 self .flex_attn_fwd_autotune_configs : list [FlexConfig ] = [
15661566 ROCmFlexConfig (BLOCK1 , BLOCK2 , 1 , w )
1567- for BLOCK1 in [16 , 64 , 128 , 256 ]
1567+ for BLOCK1 in [16 , 64 , 128 ]
15681568 for BLOCK2 in [16 , 32 , 64 , 128 ]
15691569 for w in [4 , 8 ]
15701570 ]
@@ -1590,7 +1590,7 @@ def __init__(self) -> None:
15901590
15911591 self .exhaustive_flex_attn_fwd_configs : list [FlexConfig ] = [
15921592 ROCmFlexConfig (BLOCK_M , BLOCK_N , num_stages , num_warps , mfma , wpeu )
1593- for BLOCK_M in [16 , 32 , 64 , 128 , 256 ]
1593+ for BLOCK_M in [16 , 32 , 64 , 128 ]
15941594 for BLOCK_N in [32 , 64 , 128 ]
15951595 for num_stages in [1 , 2 ]
15961596 for num_warps in [2 , 4 , 8 ]
@@ -1735,7 +1735,7 @@ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfi
17351735 if dtype == torch .float32 :
17361736 default_config = ROCmFlexConfig (64 , 64 , 1 , 4 )
17371737 else :
1738- default_config = ROCmFlexConfig (128 , 64 , 2 , 8 )
1738+ default_config = ROCmFlexConfig (128 , 64 , 2 , 4 )
17391739 default_config = self .default_flex_config .get (
17401740 (dtype , head_dim ), default_config
17411741 )
0 commit comments