Add bitnet1.58 with custom metal kernel #219
Conversation
There is probably space for further optimizations, and I have seen slightly faster inference for this model, but I believe it's a good start!

Fusing QKV can probably get us to ~50 tokens/sec, but I couldn't find a way to join the scales and keep model coherence.

Hi @Blaizzy - the results look impressive. This could potentially be ported easily to other existing Bitnet models, such as Falcon-E or Falcon3-1.58, since we can use the same BitNetLinear kernel and they are based on the Llama architecture. I am happy to work on integrating these models as well, would the

Currently, the way to check whether these models use bitnet is this config attribute: https://huggingface.co/tiiuae/Falcon-E-3B-Instruct/blob/main/config.json#L27

Thank you @younesbelkada! Yes, I can put all the BitLinear layer logic into a separate file that can be re-used, much like the existing. Funny you mention Falcon H1: I worked on an MLX version a few weeks back, but in a separate repo. I would love to collaborate with you on it.

Awesome, that makes things easier!

@Blaizzy - great to hear that! I'll reach out to you separately by email (the one you put on your GH profile); I would love to collab on H1 and Bitnet as well. Sending you an email now.

Perfect! I have moved the BitLinear layer logic to a separate file for easier reusability across models and projects.
Thank you @Blaizzy, will give it a try, and I'll either open a PR to your branch or on main if this PR gets merged quickly.

My pleasure! You can send a PR to my branch 👌🏽

Got a working version: Blaizzy#1

It also works on previous Bitnet models, e.g. https://huggingface.co/tiiuae/Falcon3-7B-Instruct-1.58bit :

Perfect, great job @younesbelkada! 🚀

Left a comment in your PR @younesbelkada; after that is resolved we can merge.

Thank you @Blaizzy! Also sent you an email for the collab 🙏

Just added support for N-bit quants for the 1.58bit model; it further reduces peak memory and is faster.
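For context, here is a minimal numpy sketch of what N-bit group quantization can look like. This is an assumed affine-per-group scheme for illustration only, not necessarily the exact scheme this PR implements; the function names, `bits`, and `group_size` are all illustrative:

```python
import numpy as np

def quantize_groups(w, bits=4, group_size=32):
    """Affine group quantization sketch (assumed scheme, not the PR's).

    Each group of `group_size` weights stores `bits`-bit integer codes
    plus a float scale and minimum, reducing peak memory versus bf16.
    """
    g = w.reshape(-1, group_size)
    wmin = g.min(axis=1, keepdims=True)
    scale = (g.max(axis=1, keepdims=True) - wmin) / (2 ** bits - 1)
    scale = np.where(scale == 0, 1.0, scale)  # guard all-equal groups
    q = np.round((g - wmin) / scale).astype(np.uint8)
    return q, scale, wmin

def dequantize_groups(q, scale, wmin, shape):
    """Reconstruct approximate weights from codes, scales and minima."""
    return (q * scale + wmin).reshape(shape)
```

Rounding to the group grid bounds the per-weight error by half a group scale, which is why smaller groups trade memory for accuracy.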
@Blaizzy your code looks really nice. Two questions: first, your 0,1,2 to -1,0,1 shift is one additional operation; can we avoid it? Second, do you plan to include the fused QKV kernels too?

I'm afraid we can't avoid it.
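To illustrate why the shift is inherent to the format (a numpy sketch for clarity, not the actual metal kernel): the 2-bit codes packed into each byte are unsigned, so the ternary values {-1, 0, 1} have to be stored as {0, 1, 2}, and decoding necessarily applies the inverse `- 1`:

```python
import numpy as np

def pack_ternary(w):
    """Pack ternary weights {-1, 0, 1} into uint8, 4 values per byte.

    Each weight becomes an unsigned 2-bit code in {0, 1, 2} (value + 1),
    so unpacking needs the inverse shift. `w` length must divide by 4.
    """
    codes = (w + 1).astype(np.uint8).reshape(-1, 4)  # {-1,0,1} -> {0,1,2}
    return (codes[:, 0]
            | (codes[:, 1] << 2)
            | (codes[:, 2] << 4)
            | (codes[:, 3] << 6))

def unpack_ternary(packed):
    """Unpack uint8 bytes back to ternary weights."""
    codes = np.stack([(packed >> s) & 0b11 for s in (0, 2, 4, 6)], axis=1)
    return codes.reshape(-1).astype(np.int8) - 1  # the unavoidable shift

# Round trip preserves the weights exactly.
w = np.array([-1, 0, 1, 1, 0, -1, 1, 0], dtype=np.int8)
assert np.array_equal(unpack_ternary(pack_ternary(w)), w)
```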
Yes, I plan to include the fused QKV kernels as soon as I find the best way to aggregate the weight scales for QKV.

Fused QKV kernel is coming soon 🚀
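One possible way to handle the scale aggregation, sketched in numpy (this is an illustration of the idea, not necessarily what the final kernel does): keep the three per-projection scales and apply them per output slice after a single fused matmul, since no single joint scalar can reproduce three different scales without changing the outputs:

```python
import numpy as np

def fused_qkv(x, w_q, w_k, w_v, s_q, s_k, s_v):
    """Fuse the Q, K, V ternary matmuls while keeping separate scales."""
    # One concatenated weight matrix -> one fused matmul.
    w = np.concatenate([w_q, w_k, w_v], axis=0).astype(np.float32)
    y = x @ w.T
    # A single scalar cannot replace (s_q, s_k, s_v), so scale each
    # output slice with its own factor instead.
    scales = np.repeat([s_q, s_k, s_v],
                       [w_q.shape[0], w_k.shape[0], w_v.shape[0]])
    return y * scales
```

The per-slice scaling is a cheap elementwise multiply, so the fused matmul still dominates the cost.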
On an M3 with 128 GB: generation at 89.3 ± 0.1 tok/sec; with --max-tokens 100, 81 ± 0.2 tok/sec.

Q: Do you know where I can find a good example script for fine-tuning a Bitnet model or Falcon using MLX?

Not sure if this is implemented in MLX, but for either Microsoft bitnet or Falcon-E you will need to download the pre-quantized weights (for bitnet it's in a separate repo: https://huggingface.co/microsoft/bitnet-b1.58-2B-4T-bf16 - and for Falcon-E it's on a separate revision
Ok, so I posted some updates:

This is good to merge. Thanks for the contributions everyone!!

My pleasure! The llama changes are there to allow for the Falcon bitnet. I can make a separate PR if you don't mind.

Thanks a lot, I noticed the speedup! (100 → 115 tok/s)

Yes, I gathered as much. You can send a PR. I removed it from this one because the added complexity wasn't great. I don't yet have a good suggestion for how to do it in a simpler way, but that's something that would be good to think through.
* Port mlx-lm bitnet1.58 ml-explore/mlx-lm#219
* update: Update relu2 function to use compile for shapeless input
* refactor: Update BitLinear & Format code
* update: Add quantization parameters to BaseConfiguration
* update: Improve error handling during weight pre-loading in ContentView
* update: Add bitnet_b1_58_2b_4t_4bit model configuration to LLMModelFactory
* update: Rename relu2 function to reluSquared and refactor implementation
* update: ACKNOWLEDGMENTS.md
* Improve the bitnet kernel
* remove: eliminate reluSquared function from Bitnet.swift
* refactor: update kernel
* add bitnet
* update activation to relu2
* working bitnet
* remove artifacts
* remove logging
* add custom post quant
* fix dtype and add compile
* fixed weight unpack
* add custom kernel to avoid memory overhead
* compile relu2
* fix weight scale
* remove unused
* add tests and update tuner utils
* update acknowledgements
* add kernel caching
* add act_quant and set float16 as default dtype
* use mx.add and move scaling to kernel
* remove act quant
* move bitlinear layers to separate file
* feat: add falcon-e and other bitnet support
* refactor: address comments
* add support for 1.58bit N-bit quants
* 43.85% speedup in generation performance (M3 max)
* refactor utils
* remove masking (2% gen speed improvement)
* add quantization config
* test llama bitnet
* refactor apply_hf_quant
* default threadgroup: 64 -> 32
* add comment
* fix prompt processing perf
* remove modulo
* compile kernel in the constructor
* Improve the bitnet kernel
* remove benchmark
* refactor bitlinear swap
* format
* remove llama changes
* revert utils
* faster + cleanup
* not trainable
* fix tests

---------

Co-authored-by: younesbelkada <younes.belkada@tii.ae>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>








This PR adds support for bitnet1.58 and implements a custom metal kernel that performs matrix multiplication directly on packed weights, eliminating the need to store unpacked weights in memory. Additionally, it lets you quantize these 1.58-bit models using N-bit quants for better performance; read more here.
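As a rough reference for what the kernel computes, here is a numpy sketch that unpacks explicitly for clarity; the actual metal kernel decodes the 2-bit codes inside the matmul, so the unpacked weight matrix never materializes (the function name and byte layout below are illustrative assumptions):

```python
import numpy as np

def bitlinear_forward(x, packed_w, weight_scale, out_features):
    """Reference BitLinear forward pass on packed ternary weights.

    `packed_w` stores four 2-bit codes per byte (low bits first). The
    metal kernel reads these bytes directly during the matmul; this
    reference version unpacks them into a dense matrix instead.
    """
    codes = np.stack([(packed_w >> s) & 0b11 for s in (0, 2, 4, 6)], axis=-1)
    w = codes.reshape(out_features, -1).astype(np.float32) - 1.0  # {0,1,2}->{-1,0,1}
    return (x @ w.T) * weight_scale
```

Because the packed matrix is 4x smaller than an int8 layout (and far smaller than bf16), decoding in-kernel saves both peak memory and bandwidth.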
Models supported:
w/o the kernel

w/ the custom metal kernel
