Implement proper serialization of Linear8bitLt #159
Conversation
```python
if state.tile_indices is None:
    order, tile_size = state.formatB, state.get_tile_size()
    # Run the layout transform on the GPU, then move the result back to
    # the tensor's original device instead of leaving it on the default GPU.
    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
```
There was a slight possibility of a bug before: `.cuda()` might transfer the tensor to GPU 0 instead of the GPU we're using for the weights.
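A minimal sketch of the pitfall (hypothetical, not from the PR; assumes at least two GPUs):

```python
import torch

x = torch.randn(2, 2, device="cuda:1")  # the tensor lives on cuda:1

# .cuda() with no argument targets the current device (typically cuda:0),
# so the result can silently land on a different GPU than the weights:
y = x.cuda()

# The fix in the snippet above moves the result back to the original device:
z = x.cuda().to(x.device)
assert z.device == x.device
```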
Thanks a mile for this PR! Definitely looking forward to merging it, as it will help push 8-bit weights to the Hub as well, as requested by the community.
```python
# Tile sizes for the GPU-specific weight layouts:
# col_turing tiles are 8x32, col_ampere tiles are 32x32.
return (8, 32) if self.formatB == "col_turing" else (32, 32)
```

```python
@property
def tile_indices(self):
    ...
```
Why not use `cached_property`?
As far as I know, `cached_property` is only available since Python 3.8, so the code wouldn't work on systems with older Python. Still, thanks for the suggestion; I hadn't actually thought of it when working on the PR!
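For illustration, the two approaches side by side (a minimal sketch; `expensive_compute` is a hypothetical stand-in for the actual tile-index computation):

```python
from functools import cached_property  # Python >= 3.8 only

def expensive_compute():
    return list(range(1024))  # placeholder for the real tile-index computation

class WithCachedProperty:
    @cached_property  # computed on first access, then cached on the instance
    def tile_indices(self):
        return expensive_compute()

class WithManualCache:
    def __init__(self):
        self._tile_indices = None

    @property
    def tile_indices(self):
        # Manual memoization: same behavior, works on Python < 3.8 too.
        if self._tile_indices is None:
            self._tile_indices = expensive_compute()
        return self._tile_indices
```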
younesbelkada left a comment:
Thanks so much for this! I tried it and it works with transformers. This PR would unlock the possibility of pushing 8-bit models to the Hub for transformers users:
huggingface/transformers#22177
Thank you, Max! This is an awesome PR with a great feature. It looks almost ready to merge to me. I think there is one case that is currently missing, though: if a user has a GPU with compute capability less than 7.5 (GTX 1080 or older), only `.CB` is not None and `.CxB` is None.
You can create a test case for this with a >= 7.5 GPU by using the `state.force_no_igemmlt` variable; see this `Linear8bitLt` test case for an example. Could you check your implementation for this case, make the necessary changes, and report back? Thank you!
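A rough sketch of such a test (layer sizes and control flow are assumptions; the flag is the one named above, and this is not a copy of the project's actual test):

```python
import torch
import bitsandbytes as bnb

# Emulate the pre-7.5 code path on a newer GPU: with force_no_igemmlt set,
# only state.CB is populated and state.CxB stays None, which is exactly
# the case the serialization code needs to handle.
linear = bnb.nn.Linear8bitLt(32, 96, has_fp16_weights=False)
linear.state.force_no_igemmlt = True
linear = linear.cuda()

x = torch.randn(4, 32, dtype=torch.float16, device="cuda")
linear(x)  # the forward pass populates the quantization state

assert linear.state.CxB is None   # the CB-only path is active
state_dict = linear.state_dict()  # serialization should still work here
```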
Thank you, Max, for this PR. If the test passes, this looks good to me. I will merge and push out a release soon.
Hi @TimDettmers, thanks for the review! I've implemented the changes for this case. As such, I think that users with older GPUs might face some inconsistencies in behavior. I was thinking of adding an extra check during serialization along the lines of what you described; alternatively, we can declare the feature a prototype and release a follow-up fix if necessary.
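A hypothetical sketch of such a serialization check (the helper name is an assumption, not part of this PR):

```python
import warnings

def warn_if_cb_only(state):
    # Hypothetical helper: on GPUs with compute capability < 7.5 only
    # state.CB is populated, and the layout-specific state.CxB cannot be
    # reconstructed identically, so reloaded weights may behave differently.
    if state.CB is not None and state.CxB is None:
        warnings.warn(
            "Serializing Linear8bitLt weights from the CB-only code path; "
            "treat 8-bit serialization on pre-Turing GPUs as a prototype."
        )
```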
As mentioned in huggingface/transformers#20247 and as often requested by users of https://github.com/bigscience-workshop/petals, it would be very useful to save and load 8-bit quantized parameters of `Linear8bitLt` (e.g., to save bandwidth when downloading large models only for inference). This PR implements that feature by adding the missing `SCB` set of float16 quantization statistics to the state dict of the layer.

Importantly,

1. If `has_fp16_weights` is True, then 8-bit quantized weights are not stored in the state and thus not serialized.
2. Loading 8-bit weights into a layer with `has_fp16_weights==True` is not supported, as it would require dequantizing the weights.

The test I've implemented handles all combinations of use cases and verifies that loading the serialized weights either results in an exception (if rule 2 is violated) or the same output as before serialization.
cc @borzunov @younesbelkada
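For context, a minimal usage sketch of the round trip this PR enables (hypothetical layer shapes and file name; assumes a CUDA GPU):

```python
import torch
import bitsandbytes as bnb

# With has_fp16_weights=False, moving the layer to the GPU quantizes the
# weights, and (after this PR) the int8 weights plus their SCB statistics
# land in the state dict (rule 1 above).
linear = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
torch.save(linear.state_dict(), "linear8bit.pt")

# Loading must also use has_fp16_weights=False; loading quantized weights
# into a has_fp16_weights=True layer would require dequantization (rule 2).
restored = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
restored.load_state_dict(torch.load("linear8bit.pt"))
restored = restored.cuda()

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
assert torch.allclose(linear(x), restored(x))  # same output after the round trip
```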