@@ -947,22 +947,25 @@ bool ConvertFP8::is_equivalent(const Primitive& other) const {
947947// turbo_quant.h opens its own namespace mlx::core::fast so it must be
948948// included OUTSIDE our namespace block to avoid double-nesting.
949949//
950- // K record layout (68 bytes per head_dim= 128 vector ):
950+ // K record layout per 128-dim sub-group (68 bytes ):
951951// [ 0.. 47] indices[48] — 3-bit PolarQuant indices, LSB-packed
952952// [ 48.. 63] qjl_signs[16] — 1-bit QJL sign bits
953953// [ 64.. 65] norm_fp16[2] — original L2 norm as fp16
954954// [ 66.. 67] rnorm_fp16[2] — residual L2 norm as fp16
955955//
956- // V record layout (50 bytes per head_dim= 128 vector ):
956+ // V record layout per 128-dim sub-group (50 bytes ):
957957// [ 0.. 47] indices[48] — 3-bit PolarQuant indices, LSB-packed
958958// [ 48.. 49] norm_fp16[2] — corrected L2 norm as fp16
959+ //
960+ // For head_dim=256 we process two consecutive 128-dim sub-groups;
961+ // the packed record is simply 2x the single-group size.
959962
960963#include " mlx/fast/turbo_quant.h" // brings in mlx::core::fast types
961964
namespace {
// Packed record sizes in bytes for ONE 128-dim sub-group. These must stay in
// sync with sdpa_vector.h (the Metal decompression kernel).
//
// K record: indices[48] + qjl_signs[16] + norm_fp16[2] + rnorm_fp16[2].
constexpr int TURBO_K_RECORD = 48 + 16 + 2 + 2;
// V record: indices[48] + norm_fp16[2].
constexpr int TURBO_V_RECORD = 48 + 2;

// Guard against accidental edits to the field sizes above: the on-disk/on-GPU
// layout is fixed at 68 and 50 bytes respectively.
static_assert(TURBO_K_RECORD == 68, "K record must be 68 bytes per sub-group");
static_assert(TURBO_V_RECORD == 50, "V record must be 50 bytes per sub-group");
} // anonymous namespace
967970
968971namespace mlx ::core::fast {
@@ -979,62 +982,75 @@ turbo_to_f32(const mlx::core::array& x, mlx::core::StreamOrDevice s) {
979982array turbo_encode_k (const array& keys, StreamOrDevice s_) {
980983 auto s = to_stream (s_);
981984
982- if (keys.shape (-1 ) != ::mlx::core::fast::TURBO_D) {
985+ const int head_dim = static_cast <int >(keys.shape (-1 ));
986+ if (head_dim != 128 && head_dim != 256 ) {
983987 throw std::invalid_argument (
984- " [turbo_encode_k] last dim (head_dim) must be " +
985- std::to_string (::mlx::core::fast::TURBO_D) + " but got " +
986- std::to_string (keys.shape (-1 )));
988+ " [turbo_encode_k] last dim (head_dim) must be 128 or 256 but got " +
989+ std::to_string (head_dim));
987990 }
988991
992+ // For D=256 we split each vector into two consecutive 128-dim sub-groups.
993+ const int n_subgroups = head_dim / ::mlx::core::fast::TURBO_D; // 1 or 2
994+ const int record_bytes = TURBO_K_RECORD * n_subgroups;
995+
989996 auto [keys_f32, src] = turbo_to_f32 (keys, s);
990- const int N = static_cast <int >(keys_f32.size () / ::mlx::core::fast::TURBO_D );
997+ const int N = static_cast <int >(keys_f32.size () / head_dim );
991998
992- std::vector<uint8_t > buf (static_cast <size_t >(N) * TURBO_K_RECORD , 0u );
999+ std::vector<uint8_t > buf (static_cast <size_t >(N) * record_bytes , 0u );
9931000
9941001 for (int i = 0 ; i < N; ++i) {
995- ::mlx::core::fast::TurboQuantK rec =
996- ::mlx::core::fast::turbo_quantize_k (
997- src + i * ::mlx::core::fast::TURBO_D,
998- ::mlx::core::fast::TURBO_D);
999- uint8_t * dst = buf.data () + i * TURBO_K_RECORD;
1000- std::memcpy (dst, rec.indices , 48 );
1001- std::memcpy (dst + 48 , rec.qjl_signs , 16 );
1002- std::memcpy (dst + 64 , &rec.norm_fp16 , 2 );
1003- std::memcpy (dst + 66 , &rec.rnorm_fp16 , 2 );
1002+ uint8_t * dst = buf.data () + i * record_bytes;
1003+ for (int g = 0 ; g < n_subgroups; ++g) {
1004+ ::mlx::core::fast::TurboQuantK rec =
1005+ ::mlx::core::fast::turbo_quantize_k (
1006+ src + i * head_dim + g * ::mlx::core::fast::TURBO_D,
1007+ ::mlx::core::fast::TURBO_D);
1008+ uint8_t * sub_dst = dst + g * TURBO_K_RECORD;
1009+ std::memcpy (sub_dst, rec.indices , 48 );
1010+ std::memcpy (sub_dst + 48 , rec.qjl_signs , 16 );
1011+ std::memcpy (sub_dst + 64 , &rec.norm_fp16 , 2 );
1012+ std::memcpy (sub_dst + 66 , &rec.rnorm_fp16 , 2 );
1013+ }
10041014 }
10051015
10061016 Shape out_shape = keys.shape ();
1007- out_shape.back () = TURBO_K_RECORD ;
1017+ out_shape.back () = record_bytes ;
10081018 return array (buf.data (), out_shape, uint8);
10091019}
10101020
10111021array turbo_encode_v (const array& values, StreamOrDevice s_) {
10121022 auto s = to_stream (s_);
10131023
1014- if (values.shape (-1 ) != ::mlx::core::fast::TURBO_D) {
1024+ const int head_dim = static_cast <int >(values.shape (-1 ));
1025+ if (head_dim != 128 && head_dim != 256 ) {
10151026 throw std::invalid_argument (
1016- " [turbo_encode_v] last dim (head_dim) must be " +
1017- std::to_string (::mlx::core::fast::TURBO_D) + " but got " +
1018- std::to_string (values.shape (-1 )));
1027+ " [turbo_encode_v] last dim (head_dim) must be 128 or 256 but got " +
1028+ std::to_string (head_dim));
10191029 }
10201030
1031+ const int n_subgroups = head_dim / ::mlx::core::fast::TURBO_D; // 1 or 2
1032+ const int record_bytes = TURBO_V_RECORD * n_subgroups;
1033+
10211034 auto [vals_f32, src] = turbo_to_f32 (values, s);
1022- const int N = static_cast <int >(vals_f32.size () / ::mlx::core::fast::TURBO_D );
1035+ const int N = static_cast <int >(vals_f32.size () / head_dim );
10231036
1024- std::vector<uint8_t > buf (static_cast <size_t >(N) * TURBO_V_RECORD , 0u );
1037+ std::vector<uint8_t > buf (static_cast <size_t >(N) * record_bytes , 0u );
10251038
10261039 for (int i = 0 ; i < N; ++i) {
1027- ::mlx::core::fast::TurboQuantV rec =
1028- ::mlx::core::fast::turbo_quantize_v (
1029- src + i * ::mlx::core::fast::TURBO_D,
1030- ::mlx::core::fast::TURBO_D);
1031- uint8_t * dst = buf.data () + i * TURBO_V_RECORD;
1032- std::memcpy (dst, rec.indices , 48 );
1033- std::memcpy (dst + 48 , &rec.norm_fp16 , 2 );
1040+ uint8_t * dst = buf.data () + i * record_bytes;
1041+ for (int g = 0 ; g < n_subgroups; ++g) {
1042+ ::mlx::core::fast::TurboQuantV rec =
1043+ ::mlx::core::fast::turbo_quantize_v (
1044+ src + i * head_dim + g * ::mlx::core::fast::TURBO_D,
1045+ ::mlx::core::fast::TURBO_D);
1046+ uint8_t * sub_dst = dst + g * TURBO_V_RECORD;
1047+ std::memcpy (sub_dst, rec.indices , 48 );
1048+ std::memcpy (sub_dst + 48 , &rec.norm_fp16 , 2 );
1049+ }
10341050 }
10351051
10361052 Shape out_shape = values.shape ();
1037- out_shape.back () = TURBO_V_RECORD ;
1053+ out_shape.back () = record_bytes ;
10381054 return array (buf.data (), out_shape, uint8);
10391055}
10401056
0 commit comments