Skip to content

Commit 48fd996

Browse files
committed
feat(turbo-kv): support head_dim=256 via two 128-dim sub-groups
Qwen3.5-122B has head_dim=256. The C++ encoder now processes D=256 as two consecutive D=128 sub-groups using the existing TurboQuantK/V structs. Record sizes double for D=256: K=136 bytes (2 × 68) and V=100 bytes (2 × 50) per token.
1 parent 4143d3b commit 48fd996

2 files changed

Lines changed: 52 additions & 36 deletions

File tree

LocalPackages/mlx-swift/Source/Cmlx/mlx/mlx/fast.cpp

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -947,22 +947,25 @@ bool ConvertFP8::is_equivalent(const Primitive& other) const {
947947
// turbo_quant.h opens its own namespace mlx::core::fast so it must be
948948
// included OUTSIDE our namespace block to avoid double-nesting.
949949
//
950-
// K record layout (68 bytes per head_dim=128 vector):
950+
// K record layout per 128-dim sub-group (68 bytes):
951951
// [ 0.. 47] indices[48] — 3-bit PolarQuant indices, LSB-packed
952952
// [ 48.. 63] qjl_signs[16] — 1-bit QJL sign bits
953953
// [ 64.. 65] norm_fp16[2] — original L2 norm as fp16
954954
// [ 66.. 67] rnorm_fp16[2] — residual L2 norm as fp16
955955
//
956-
// V record layout (50 bytes per head_dim=128 vector):
956+
// V record layout per 128-dim sub-group (50 bytes):
957957
// [ 0.. 47] indices[48] — 3-bit PolarQuant indices, LSB-packed
958958
// [ 48.. 49] norm_fp16[2] — corrected L2 norm as fp16
959+
//
960+
// For head_dim=256 we process two consecutive 128-dim sub-groups;
961+
// the packed record is simply 2x the single-group size.
959962

960963
#include "mlx/fast/turbo_quant.h" // brings in mlx::core::fast types
961964

962965
namespace {
963-
// Record byte sizes — must match sdpa_vector.h (Metal decompression kernel).
964-
static constexpr int TURBO_K_RECORD = 68;
965-
static constexpr int TURBO_V_RECORD = 50;
966+
// Record byte sizes per 128-dim sub-group — must match sdpa_vector.h.
967+
static constexpr int TURBO_K_RECORD = 68; // one 128-dim K sub-group
968+
static constexpr int TURBO_V_RECORD = 50; // one 128-dim V sub-group
966969
} // anonymous namespace
967970

968971
namespace mlx::core::fast {
@@ -979,62 +982,75 @@ turbo_to_f32(const mlx::core::array& x, mlx::core::StreamOrDevice s) {
979982
array turbo_encode_k(const array& keys, StreamOrDevice s_) {
980983
auto s = to_stream(s_);
981984

982-
if (keys.shape(-1) != ::mlx::core::fast::TURBO_D) {
985+
const int head_dim = static_cast<int>(keys.shape(-1));
986+
if (head_dim != 128 && head_dim != 256) {
983987
throw std::invalid_argument(
984-
"[turbo_encode_k] last dim (head_dim) must be " +
985-
std::to_string(::mlx::core::fast::TURBO_D) + " but got " +
986-
std::to_string(keys.shape(-1)));
988+
"[turbo_encode_k] last dim (head_dim) must be 128 or 256 but got " +
989+
std::to_string(head_dim));
987990
}
988991

992+
// For D=256 we split each vector into two consecutive 128-dim sub-groups.
993+
const int n_subgroups = head_dim / ::mlx::core::fast::TURBO_D; // 1 or 2
994+
const int record_bytes = TURBO_K_RECORD * n_subgroups;
995+
989996
auto [keys_f32, src] = turbo_to_f32(keys, s);
990-
const int N = static_cast<int>(keys_f32.size() / ::mlx::core::fast::TURBO_D);
997+
const int N = static_cast<int>(keys_f32.size() / head_dim);
991998

992-
std::vector<uint8_t> buf(static_cast<size_t>(N) * TURBO_K_RECORD, 0u);
999+
std::vector<uint8_t> buf(static_cast<size_t>(N) * record_bytes, 0u);
9931000

9941001
for (int i = 0; i < N; ++i) {
995-
::mlx::core::fast::TurboQuantK rec =
996-
::mlx::core::fast::turbo_quantize_k(
997-
src + i * ::mlx::core::fast::TURBO_D,
998-
::mlx::core::fast::TURBO_D);
999-
uint8_t* dst = buf.data() + i * TURBO_K_RECORD;
1000-
std::memcpy(dst, rec.indices, 48);
1001-
std::memcpy(dst + 48, rec.qjl_signs, 16);
1002-
std::memcpy(dst + 64, &rec.norm_fp16, 2);
1003-
std::memcpy(dst + 66, &rec.rnorm_fp16, 2);
1002+
uint8_t* dst = buf.data() + i * record_bytes;
1003+
for (int g = 0; g < n_subgroups; ++g) {
1004+
::mlx::core::fast::TurboQuantK rec =
1005+
::mlx::core::fast::turbo_quantize_k(
1006+
src + i * head_dim + g * ::mlx::core::fast::TURBO_D,
1007+
::mlx::core::fast::TURBO_D);
1008+
uint8_t* sub_dst = dst + g * TURBO_K_RECORD;
1009+
std::memcpy(sub_dst, rec.indices, 48);
1010+
std::memcpy(sub_dst + 48, rec.qjl_signs, 16);
1011+
std::memcpy(sub_dst + 64, &rec.norm_fp16, 2);
1012+
std::memcpy(sub_dst + 66, &rec.rnorm_fp16, 2);
1013+
}
10041014
}
10051015

10061016
Shape out_shape = keys.shape();
1007-
out_shape.back() = TURBO_K_RECORD;
1017+
out_shape.back() = record_bytes;
10081018
return array(buf.data(), out_shape, uint8);
10091019
}
10101020

10111021
array turbo_encode_v(const array& values, StreamOrDevice s_) {
10121022
auto s = to_stream(s_);
10131023

1014-
if (values.shape(-1) != ::mlx::core::fast::TURBO_D) {
1024+
const int head_dim = static_cast<int>(values.shape(-1));
1025+
if (head_dim != 128 && head_dim != 256) {
10151026
throw std::invalid_argument(
1016-
"[turbo_encode_v] last dim (head_dim) must be " +
1017-
std::to_string(::mlx::core::fast::TURBO_D) + " but got " +
1018-
std::to_string(values.shape(-1)));
1027+
"[turbo_encode_v] last dim (head_dim) must be 128 or 256 but got " +
1028+
std::to_string(head_dim));
10191029
}
10201030

1031+
const int n_subgroups = head_dim / ::mlx::core::fast::TURBO_D; // 1 or 2
1032+
const int record_bytes = TURBO_V_RECORD * n_subgroups;
1033+
10211034
auto [vals_f32, src] = turbo_to_f32(values, s);
1022-
const int N = static_cast<int>(vals_f32.size() / ::mlx::core::fast::TURBO_D);
1035+
const int N = static_cast<int>(vals_f32.size() / head_dim);
10231036

1024-
std::vector<uint8_t> buf(static_cast<size_t>(N) * TURBO_V_RECORD, 0u);
1037+
std::vector<uint8_t> buf(static_cast<size_t>(N) * record_bytes, 0u);
10251038

10261039
for (int i = 0; i < N; ++i) {
1027-
::mlx::core::fast::TurboQuantV rec =
1028-
::mlx::core::fast::turbo_quantize_v(
1029-
src + i * ::mlx::core::fast::TURBO_D,
1030-
::mlx::core::fast::TURBO_D);
1031-
uint8_t* dst = buf.data() + i * TURBO_V_RECORD;
1032-
std::memcpy(dst, rec.indices, 48);
1033-
std::memcpy(dst + 48, &rec.norm_fp16, 2);
1040+
uint8_t* dst = buf.data() + i * record_bytes;
1041+
for (int g = 0; g < n_subgroups; ++g) {
1042+
::mlx::core::fast::TurboQuantV rec =
1043+
::mlx::core::fast::turbo_quantize_v(
1044+
src + i * head_dim + g * ::mlx::core::fast::TURBO_D,
1045+
::mlx::core::fast::TURBO_D);
1046+
uint8_t* sub_dst = dst + g * TURBO_V_RECORD;
1047+
std::memcpy(sub_dst, rec.indices, 48);
1048+
std::memcpy(sub_dst + 48, &rec.norm_fp16, 2);
1049+
}
10341050
}
10351051

10361052
Shape out_shape = values.shape();
1037-
out_shape.back() = TURBO_V_RECORD;
1053+
out_shape.back() = record_bytes;
10381054
return array(buf.data(), out_shape, uint8);
10391055
}
10401056

mlx-swift-lm

0 commit comments

Comments (0)