Conversation

Top! Thanks @Blaizzy

I wanted to see how quickly I could do this and got something working. Please feel free to reuse anything: kernelpool/mlx-lm@kimi-linear

@kernelpool very nice! What about sending a patch to this and we can merge it in? Or is it simpler to send a separate PR?

Thank you very much @kernelpool! 🚀 @awni the fixes are merged here 👌🏽

```shell
uv run python -m mlx_lm generate --model mlx-community/Kimi-Linear-48B-A3B-Instruct-4bit --prompt "hello" -m 1024 --trust-remote-code
```

Fetching 16 files: 100% 16/16 [00:00<00:00, 54251.30it/s]

Clone my fork and install from source.
mlx_lm/models/gated_delta.py
Outdated
```python
def _make_gated_delta_kernel_vec(has_mask: bool = False):
    if not mx.metal.is_available():
        return None
```
This looks like a duplicate of the above kernel. Why do we need this? Shouldn't we just reuse the above kernel?
This was based on the FLA implementation, where there's a separate kernel to handle vectorized gating. But yeah, this can be simplified.
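For context, the recurrence these kernels fuse can be written as an unfused, per-token reference in plain NumPy. This is an illustrative sketch of the gated delta rule (scalar decay `g`, delta-rule write strength `beta`); the function name and layout are not mlx-lm's actual API:

```python
import numpy as np

def gated_delta_ref(q, k, v, g, beta, state=None):
    """Unfused gated delta rule for a single head (illustrative sketch).

    q, k: (T, d_k); v: (T, d_v); g, beta: (T,) per-step decay and write
    strength; state: optional (d_k, d_v) fast-weight matrix carried across
    calls. A fused kernel would compute the same recurrence without
    materializing per-token intermediates.
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v)) if state is None else state.copy()
    out = np.empty((T, d_v))
    for t in range(T):
        S = g[t] * S                                         # decay (gating)
        S = S + beta[t] * np.outer(k[t], v[t] - S.T @ k[t])  # delta-rule write
        out[t] = S.T @ q[t]                                  # query readout
    return out, S
```

With `g[t] = 1` and `beta[t] = 1` this reduces to the plain delta rule, so it doubles as a cheap reference to check any fused variant against.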
mlx_lm/models/gated_delta.py
Outdated
```python
if q.shape[1] > chunk_size:
    return chunked_gated_delta_kernel(
        q,
        k,
        v,
        g,
        beta,
        state,
        mask,
        chunk_size,
    )
```
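For what it's worth, chunked processing with the recurrent state carried across chunk boundaries is exactly equivalent to a single full pass, so the choice is purely about kernel efficiency. A small NumPy sketch (illustrative names and simplified scalar-gated math, not the repo's kernels) demonstrates the equivalence:

```python
import numpy as np

def step(S, q, k, v, g, beta):
    # One gated-delta step (illustrative math, not the repo's exact kernel).
    S = g * S                                # decay the fast-weight state
    S = S + beta * np.outer(k, v - S.T @ k)  # delta-rule correction
    return S, S.T @ q                        # updated state, readout

def run(qs, ks, vs, gs, betas, S):
    outs = []
    for t in range(len(qs)):
        S, o = step(S, qs[t], ks[t], vs[t], gs[t], betas[t])
        outs.append(o)
    return np.stack(outs), S

rng = np.random.default_rng(0)
T, C, dk, dv = 8, 3, 4, 5  # sequence length, chunk size, key/value dims
qs, ks = rng.normal(size=(T, dk)), rng.normal(size=(T, dk))
vs = rng.normal(size=(T, dv))
gs, betas = rng.uniform(0.8, 1.0, T), rng.uniform(0.1, 0.9, T)

full, _ = run(qs, ks, vs, gs, betas, np.zeros((dk, dv)))

# Chunked pass: process C tokens at a time, carrying the state forward.
pieces, S = [], np.zeros((dk, dv))
for i in range(0, T, C):
    o, S = run(qs[i:i+C], ks[i:i+C], vs[i:i+C], gs[i:i+C], betas[i:i+C], S)
    pieces.append(o)
chunked = np.concatenate(pieces)
assert np.allclose(full, chunked)  # identical to the single full pass
```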
What's the purpose of that over just using the gated_delta_kernel directly?
mlx_lm/models/gated_delta.py
Outdated
```diff
-if not use_kernel or mx.default_device() != mx.gpu or not mx.metal.is_available():
-    return gated_delta_ops(q, k, v, g, beta, state, mask)
-else:
-    return gated_delta_kernel(q, k, v, g, beta, state, mask)
+from . import fused_recurrent_kda as frkda
+
+if q.shape[1] > chunk_size:
+    return frkda.chunked_kda_ops(q, k, v, g, beta, state, mask, chunk_size)
+return frkda.fused_recurrent_kda_ops(q, k, v, g, beta, state, mask)
```
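The routing under discussion follows a common three-way dispatch pattern: plain array ops as the portable fallback, a chunked kernel for long sequences, and a fused recurrent kernel for short ones. Schematically (pure Python with hypothetical names, not mlx-lm's actual predicates):

```python
def route(seq_len, chunk_size, kernel_ok):
    """Schematic dispatch, not mlx-lm's actual code.

    kernel_ok stands in for predicates like "device is GPU and Metal is
    available"; the strings stand in for the three implementations.
    """
    if not kernel_ok:          # no usable fused kernel: plain-ops fallback
        return "ops"
    if seq_len > chunk_size:   # long prompt: process in chunks
        return "chunked"
    return "fused"             # short sequence: single fused recurrent pass
```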
It looks like we switched to a different function here (gated_delta_ops replaced by the frkda function). Why? As far as I can see there should be no difference.

It would be much cleaner to reuse the existing operations (which should be doable). If there is an efficiency implication I'd love to know more. Not sure who worked on that, @Blaizzy or @kernelpool; would one of you be up for improving that?

Sure, I'll take a look!

Hey @awni, yes, the kernels are yet to be optimized. I personally believe we can simplify them, and they should stay in the same file as the model until we see more models using them. So far I optimized the overall model code (2 tok/s to 70 tok/s in bf16), but I have my plate full this week, so I will only be able to pick it up during the weekend.

Regardless of whether they can be optimized, I prefer not to add new kernels and ops, but rather to use the existing ones we already have. It looks like these are the same operations as what we have for Qwen3 Next, so I would keep them in the gated delta file.
Yes, I prefer that too. Unfortunately, I didn't work on the kernels, and that's why I wanted to use the weekend to dive deep into the codebase.

I pushed a PR that simplifies and unifies the kernels, with the chunking removed. I also used @ivanfioravanti's benchmark script to measure differences between the commits (3874bc6 is the head of the PR).

awni left a comment:
Looks great, thanks for the contributions @kernelpool and @Blaizzy
# Conflicts:
#	mlx_lm/models/kimi_linear.py

