
Add bitnet1.58 with custom metal kernel #219

Merged
awni merged 42 commits into ml-explore:main from Blaizzy:pc/add-bitnet
Jul 2, 2025

Conversation

@Blaizzy
Contributor

@Blaizzy Blaizzy commented Jun 8, 2025

This PR adds support for bitnet1.58 and implements a custom Metal kernel that performs matrix multiplication directly on the packed weights. This eliminates the need to store unpacked weights in memory. Additionally, it allows you to quantize these 1.58-bit models using N-bit quants for better performance, read more here.
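To make the kernel's idea concrete, here is a pure-Python sketch (not the actual Metal source; the function names are illustrative) of multiplying directly on packed 2-bit ternary weights, unpacking each code on the fly so the full weight matrix is never materialized:

```python
def pack_ternary(weights):
    """Pack ternary weights {-1, 0, +1} four-per-byte as 2-bit codes {0, 1, 2}."""
    assert len(weights) % 4 == 0
    packed = []
    for i in range(0, len(weights), 4):
        byte = 0
        for j, w in enumerate(weights[i:i + 4]):
            byte |= (w + 1) << (2 * j)  # map -1/0/+1 -> 0/1/2
        packed.append(byte)
    return packed

def matmul_packed(packed_rows, scale, x):
    """Multiply directly on the packed rows: each 2-bit code is decoded
    on the fly, so no unpacked weight matrix ever lives in memory."""
    out = []
    for row in packed_rows:
        acc = 0.0
        k = 0
        for byte in row:
            for j in range(4):
                w = ((byte >> (2 * j)) & 0b11) - 1  # recover -1/0/+1
                acc += w * x[k]
                k += 1
        out.append(acc * scale)  # apply the per-tensor weight scale once
    return out

# Tiny example: one output row over 4 inputs.
row = pack_ternary([-1, 0, 1, 1])
y = matmul_packed([row], scale=0.5, x=[1.0, 2.0, 3.0, 4.0])
# (-1*1 + 0*2 + 1*3 + 1*4) * 0.5 = 3.0
```

The real kernel does the same decode inside the Metal threadgroup loop, but the memory-saving argument is identical: only the packed bytes and one scale are resident.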

Models supported:

w/o the kernel
Screenshot 2025-06-08 at 11 17 17 PM

w/ the custom metal kernel
Screenshot 2025-06-08 at 11 17 34 PM

Note: I removed the N-bit input quantization because it's slower (-6 tokens/s) and doesn't provide significant memory savings, even with fused kernels. I believe it's best to use KV quant instead.

@Blaizzy
Contributor Author

Blaizzy commented Jun 8, 2025

Small but mighty!
Screenshot 2025-06-09 at 12 25 28 AM

@Blaizzy
Contributor Author

Blaizzy commented Jun 8, 2025

There is probably room for further optimizations; I have seen slightly faster inference for this model.

But I believe it's a good start!

@Blaizzy
Contributor Author

Blaizzy commented Jun 9, 2025

Implemented kernel caching and reduced precision from float32 to float16, achieving significant performance gains across all metrics:

Key Improvements:

  • Throughput: 2x increase overall
  • Peak TFLOPS: 45.86 TFLOPS (up from 20.19) — 127% improvement
  • MLP Forward Pass: 40% faster execution
  • Prompt Processing: 135.046 tokens/sec (up from 51) — 165% improvement
  • Generation Speed: 46.209 tokens/sec (up from 43.830) — 5% improvement
  • Peak Memory: 1.308 GB (down from 1.322 GB) — 1% reduction
Screenshot 2025-06-09 at 4 47 08 PM

@Blaizzy
Contributor Author

Blaizzy commented Jun 9, 2025

Fusing qkv can probably get us to around ~50 tokens/s, but I couldn't find a way to join the scales and keep model coherence.

@younesbelkada
Contributor

younesbelkada commented Jun 10, 2025

Hi @Blaizzy - the results look impressive, this could potentially be easily ported to other existing Bitnet models such as Falcon-E or Falcon3-1.58 since we can use the same BitNetLinear kernel and are based on Llama architecture. I am happy to work on integrating these models as well, would the BitLinear class be re-usable by other models? How do you think we should approach this?

@younesbelkada
Contributor

Currently the way to infer if the models use bitnet for these models is to check this attribute: https://huggingface.co/tiiuae/Falcon-E-3B-Instruct/blob/main/config.json#L27

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

Hi @Blaizzy - the results look impressive. This could potentially be ported easily to other existing Bitnet models such as Falcon-E or Falcon3-1.58, since they can use the same BitNetLinear kernel and are based on the Llama architecture. I am happy to work on integrating these models as well. Would the BitLinear class be re-usable by other models? How do you think we should approach this?

Thank you @younesbelkada!

Yes, I can put all the BitLinear layer logic into a separate file that can be re-used, much like the existing switch_layers.py.

Funny you mention Falcon H1; I worked on an MLX version a few weeks back, but in a separate repo. I would love to collaborate with you on it.

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

Currently the way to infer if the models use bitnet for these models is to check this attribute: https://huggingface.co/tiiuae/Falcon-E-3B-Instruct/blob/main/config.json#L27

Awesome, that makes things easier!

@younesbelkada
Contributor

@Blaizzy - great to hear that! I'll reach out to you separately by email (the one you put on your GH profile). I would love to collaborate on H1 and Bitnet as well! Sending you an email now.

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

@Blaizzy - great to hear that! I'll reach out to you separately by email (the one you put on your GH profile). I would love to collaborate on H1 and Bitnet as well! Sending you an email now.

Perfect!

I have moved BitLinear layer logic to a separate file for easier reusability across models and projects.

@younesbelkada
Contributor

Thank you @Blaizzy, I'll give it a try and either open a PR against your branch, or against main if this PR gets merged quickly.

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

My pleasure!

You can send a PR to my branch 👌🏽

@younesbelkada
Contributor

Got a working version: Blaizzy#1

Screenshot 2025-06-10 at 2 35 45 PM

@younesbelkada
Contributor

Also works on previous Bitnet models, e.g. https://huggingface.co/tiiuae/Falcon3-7B-Instruct-1.58bit :

Screenshot 2025-06-10 at 2 41 36 PM

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

Perfect, great job @younesbelkada ! 🚀

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

I was about to text you saying I tried Falcon 3 but the hidden states were exploding to infinity 😂

But the inverse change fixes it ✅

Screenshot 2025-06-10 at 12 46 31 PM

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

Left a comment in your PR @younesbelkada after that is resolved we can merge.

@younesbelkada
Contributor

Thank you @Blaizzy ! Also sent you an email for collab 🙏
Looking forward to seeing this merged 🚀

@younesbelkada
Contributor

Performance of the 1B model:

Screenshot 2025-06-10 at 4 44 48 PM

@Blaizzy
Contributor Author

Blaizzy commented Jun 10, 2025

Just added support for N-bit quants for the 1.58-bit model; it further reduces peak memory and is faster.

@awni @angeloskath
Screenshot 2025-06-10 at 3 48 45 PM

@Blaizzy
Contributor Author

Blaizzy commented Jun 12, 2025

This last commit is small but makes a huge difference!

MLX is now officially faster than bitnet.cpp, without even using N-bit quants that add +10 tokens/s.

And I believe with fused qkv kernels we can get an additional 5-10 tokens/s :)

Screenshot 2025-06-12 at 3 56 41 AM

2T = 2 threads
4T = 4 threads

@guillaume-osmo

@Blaizzy your code looks really nice. One question: your 0,1,2 to -1,0,1 shift is one additional operation; can we avoid it? Second question: do you plan to include the fused qkv kernels too?

@Blaizzy
Contributor Author

Blaizzy commented Jun 12, 2025

I'm afraid we can't avoid it.

  • 2-bit packing – 4 weights per byte means each weight is an unsigned 2-bit code (00–11).
  • Ternary encoding – store {-1, 0, +1} as {00, 01, 10} and recover with bits - 1.
  • Mat-mul runs in FP16/FP32 – inputs are floats and Apple GPUs multiply-accumulate only in FP16/32 (or INT8 after unpacking), so every weight must be promoted to FP anyway.
  • Negligible cost – float(bits) - 1 is a single fused ALU op; any alternate encoding adds more instructions or an extra reduction.

@Blaizzy
Contributor Author

Blaizzy commented Jun 12, 2025

Yes, I plan to include the fused QKV kernels as soon as I find the best way to aggregate the weight scales for QKV.
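One plausible way to fuse QKV without joining the scales (a sketch of the general idea only, not this PR's implementation; all names are illustrative) is to concatenate the three ternary matrices along the output dimension, run a single fused matmul, and apply each projection's own scale to its slice of the output:

```python
def matvec(rows, x):
    """Plain matrix-vector product over lists of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in rows]

def fused_qkv(wq, wk, wv, sq, sk, sv, x):
    """Fuse Q, K, V into one matmul while keeping per-projection scales.

    The ternary matrices are stacked along the output dimension, so there
    is one kernel launch; scales stay separate and are applied per slice,
    avoiding the need to join them into a single coherent scale.
    """
    fused = wq + wk + wv                  # one weight matrix, one matmul
    y = matvec(fused, x)
    nq, nk = len(wq), len(wk)
    q = [v * sq for v in y[:nq]]
    k = [v * sk for v in y[nq:nq + nk]]
    v_out = [v * sv for v in y[nq + nk:]]
    return q, k, v_out

# Tiny example: 1-row Q, K, V projections over a 2-dim input.
q, k, v = fused_qkv([[1, -1]], [[0, 1]], [[1, 1]], 0.5, 2.0, 1.0, [3.0, 4.0])
```

This keeps each projection numerically identical to the unfused version, at the cost of three small scale multiplies after the fused matmul rather than one.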

@Blaizzy
Contributor Author

Blaizzy commented Jun 12, 2025

Fused QKV kernel is coming soon 🚀

@guillaume-osmo

guillaume-osmo commented Jun 12, 2025

So on an M3 128 GB it generates at 89.3 +/- 0.1 tok/sec.
Why is the first prompt call always slower (model loading maybe)?
The first prompt runs at 52.5 tok/sec and subsequent ones at 162.2 +/- 0.1 tok/sec,
using:

mlx_lm.generate --model microsoft/bitnet-b1.58-2B-4T  --prompt "implement bubble sort from scratch" --max-tokens 10 --temp 0.7

with --max-tokens 100 => 81 +/-0.2 tok/sec

@guillaume-osmo

Q: Do you know where I can find a good example script to fine-tune a Bitnet or Falcon model using MLX?

@younesbelkada
Contributor

younesbelkada commented Jun 12, 2025

Q: Do you know where I can find a good example script to fine-tune a Bitnet or Falcon model using MLX?

Not sure if this is implemented in MLX, but for either the Microsoft bitnet or Falcon-E you will need to download the pre-quantized weights (for bitnet they're in a separate repo: https://huggingface.co/microsoft/bitnet-b1.58-2B-4T-bf16 - and for Falcon-E they're on a separate prequantized revision), then use another implementation of BitLinear tailored for training.
The implementation for MS-Bitnet is here: https://github.com/huggingface/transformers/blob/27459025b8f77d53631f7961cc967fa659d43f7e/src/transformers/integrations/bitnet.py#L307 (you need to consider online_quant).
The implementation of the Falcon-E bitnet layer is here: https://github.com/tiiuae/onebitllms/blob/main/src/onebitllms/layers/bitnet.py#L20
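For reference, the training-time BitLinear layers linked above quantize the weights online on each forward pass. A minimal sketch of the BitNet b1.58 absmean scheme (illustrative only; the real implementations also route gradients through a straight-through estimator):

```python
def quantize_ternary(w):
    """Absmean quantization: scale weights by 1 / mean(|w|), then round
    and clip to the ternary set {-1, 0, +1}. Returns the quantized
    weights and the scale needed to dequantize them later."""
    eps = 1e-5
    mean_abs = sum(abs(v) for v in w) / len(w)
    scale = 1.0 / max(mean_abs, eps)
    wq = [max(-1, min(1, round(v * scale))) for v in w]
    return wq, scale

wq, scale = quantize_ternary([0.9, -0.05, 0.4, -1.1])
# every quantized weight is ternary
assert all(v in (-1, 0, 1) for v in wq)
```

Because the quantization happens online during training, the checkpoints you download must already be pre-quantized (or carry the bf16 master weights, as in the linked Microsoft repo) before they can be used for inference-only setups like this PR's kernel.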

@awni awni force-pushed the pc/add-bitnet branch from 2ffcb79 to 00842d2 on July 2, 2025
@awni awni force-pushed the pc/add-bitnet branch from 00842d2 to 7e1666b on July 2, 2025
@awni
Member

awni commented Jul 2, 2025

Ok so I posted some updates:

  • General cleanup
  • Removed support for Bitnet quantizing the Llama models. Let's revisit that in a follow-on if we want to support it, but I want to get this landed and keep the diff self-contained.
  • Improved the bitnet layer. There was some inefficient casting happening there, so it's even a bit faster now.

This is good to merge. Thanks for the contributions everyone!!

@Blaizzy
Contributor Author

Blaizzy commented Jul 2, 2025

My pleasure!

The llama changes are there to allow for Falcon bitnet. I can make a separate PR if you don’t mind.

@Blaizzy
Contributor Author

Blaizzy commented Jul 2, 2025

There was some inefficient casting happening there so it's even a bit faster now.

Thanks a lot, I noticed the speed-up! (100 → 115 tok/s)

static_cast<T> fixes the bug I was having 🙌🏽

@awni
Member

awni commented Jul 2, 2025

The llama changes are there to allow for Falcon bitnet. I can make a separate PR if you don’t mind.

Yes, I gathered as much. You can send a PR. I removed it from this one because the added complexity wasn't great. I don't yet have a good suggestion for how to do it in a simpler way, but that's something that would be good to think through.

@awni awni merged commit 5fa62eb into ml-explore:main Jul 2, 2025
4 checks passed
davidkoski pushed a commit to ml-explore/mlx-swift-examples that referenced this pull request Jul 3, 2025
* Port mlx-lm bitnet1.58 ml-explore/mlx-lm#219

* update: Update relu2 function to use compile for shapeless input

* refactor: Update BitLinear & Format code

* update: Add quantization parameters to BaseConfiguration

* update: Improve error handling during weight pre-loading in ContentView

* update: Add bitnet_b1_58_2b_4t_4bit model configuration to LLMModelFactory

* update: Rename relu2 function to reluSquared and refactor implementation

* update: ACKNOWLEDGMENTS.md

* Improve the bitnet kernel

* remove: eliminate reluSquared function from Bitnet.swift

* refactor: update kernel
dojoteef pushed a commit to dojoteef/mlx-lm that referenced this pull request Aug 1, 2025
* add bitnet

* update activation to relu2

* working bitnet

* remove artifacts

* remove logging

* add custom post quant

* fix dtype and add compile

* fixed weight unpack

* add custom kernel to avoid memory overhead

* compile relu2

* fix weight scale

* remove unused

* add tests and update tuner utils

* update acknowledgements

* add kernel caching

* add act_quant and set float16 as default dtype

* use mx.add and move scaling to kernel

* remove act quant

* move bitlinear layers to separate file

* feat: add falcon-e and other bitnet support

* refactor: address comments

* add support for 1.58bit N-bit quants

* 43.85% speedup in generation performance (M3 max)

* refactor utils

* remove masking (2% gen speed improvement)

* add quantization config

* test llama bitnet

* refactor apply_hf_quant

* default threadgroup: 64 -> 32

* add comment

* fix prompt processing perf

* remove modulo

* compile kernel in the constructor

* Improve the bitnet kernel

* remove benchmark

* refactor bitlinear swap

* format

* remove llama changes

* revert utils

* faster + cleanup

* not trainable

* fix tests

---------

Co-authored-by: younesbelkada <younes.belkada@tii.ae>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
davidkoski pushed a commit to ml-explore/mlx-swift-lm that referenced this pull request Nov 3, 2025
* Port mlx-lm bitnet1.58 ml-explore/mlx-lm#219

* update: Update relu2 function to use compile for shapeless input

* refactor: Update BitLinear & Format code

* update: Add quantization parameters to BaseConfiguration

* update: Improve error handling during weight pre-loading in ContentView

* update: Add bitnet_b1_58_2b_4t_4bit model configuration to LLMModelFactory

* update: Rename relu2 function to reluSquared and refactor implementation

* update: ACKNOWLEDGMENTS.md

* Improve the bitnet kernel

* remove: eliminate reluSquared function from Bitnet.swift

* refactor: update kernel
7 participants