
Conversation

@mryab (Collaborator) commented Feb 21, 2023:

As mentioned in huggingface/transformers#20247 and as often requested by users of https://github.com/bigscience-workshop/petals, it would be very useful to save and load the 8-bit quantized parameters of Linear8bitLt (e.g., to save bandwidth when downloading large models only for inference). This PR implements that feature by adding the missing SCB quantization statistics of the quantized weights to the state dict of the layer.

Importantly,

  1. One needs to first trigger the quantization by calling layer.cuda().
  2. Before loading the weights into the model, one also needs to call layer.cuda() to allocate the necessary buffers.
  3. If has_fp16_weights is True, the 8-bit quantized weights are not stored in the state dict and thus are not serialized.
  4. Loading 8-bit weights into a layer with has_fp16_weights==True is not supported, as it would require dequantizing the weights.

The test I've implemented covers all combinations of these use cases and verifies that loading the serialized weights either raises an exception (if rule 2 is violated) or produces the same output as before serialization.
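
A minimal save/load round trip following the rules above (a sketch only; the layer size, file name, and has_fp16_weights setting are arbitrary choices for illustration):

```python
import torch
import bitsandbytes as bnb

# Save: trigger int8 quantization by moving the layer to the GPU first (rule 1).
layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
layer.cuda()  # quantizes the weights; SCB now appears in the state dict
torch.save(layer.state_dict(), "linear8bit.pt")

# Load: allocate the int8 buffers with .cuda() first (rule 2), then restore.
restored = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
restored.cuda()
restored.load_state_dict(torch.load("linear8bit.pt"))
```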

cc @borzunov @younesbelkada

@mryab mryab changed the title [WIP] Implement proper serialization of Linear8bitLt Implement proper serialization of Linear8bitLt Feb 25, 2023
@mryab mryab marked this pull request as ready for review February 25, 2023 15:03
@mryab mryab requested a review from TimDettmers February 25, 2023 15:03

if state.tile_indices is None:
    order, tile_size = state.formatB, state.get_tile_size()
    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
@mryab (Collaborator, Author) commented:

There was a slight possibility of a bug before: .cuda() might transfer the tensor to GPU 0 instead of the GPU we're using for the weights.
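
In other words (a tiny illustration with a hypothetical tensor; not code from the diff):

```python
import torch

x = torch.arange(16).reshape(4, 4)  # may live on the CPU or on any GPU
y = x.cuda()                        # lands on the *current* CUDA device, typically cuda:0
y = y.to(x.device)                  # move the result back to x's original device
```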

@younesbelkada (Collaborator) commented:

Thanks a mile for this PR, definitely looking forward to merging it, as it will help push 8-bit weights to the Hub as well, as requested by the community.
I will try this PR whenever I can on a dummy int8 model using transformers.

        return (8, 32) if self.formatB == "col_turing" else (32, 32)

    @property
    def tile_indices(self):

Why not use cached_property?

@mryab (Collaborator, Author) commented Mar 3, 2023:

As far as I know, it is only available since Python 3.8, so the code wouldn't work on systems with older Python. Still, thanks for the suggestion; I hadn't actually thought of it when working on the PR!
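
For reference, a hand-rolled equivalent of the caching that works on older Pythons (a sketch; the class and the _compute_tile_indices helper are illustrative stand-ins, not the actual bitsandbytes code):

```python
class State:
    def __init__(self):
        self._tile_indices = None  # computed lazily on first access

    def _compute_tile_indices(self):
        # Stand-in for the real (expensive) computation.
        return list(range(8 * 32))

    @property
    def tile_indices(self):
        # Manual lazy caching, equivalent to functools.cached_property,
        # which is only available on Python >= 3.8.
        if self._tile_indices is None:
            self._tile_indices = self._compute_tile_indices()
        return self._tile_indices
```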

@younesbelkada (Collaborator) left a comment:

Thanks so much for this!
I tried it and it works with transformers; this PR would unlock the possibility of pushing 8-bit models to the Hub for transformers users: huggingface/transformers#22177

@TimDettmers (Collaborator) left a comment:

Thank you, Max! This is an awesome PR with a great feature. It looks almost ready to merge to me. I think there is one case that is currently missing, though: if a user has a GPU with compute capability below 7.5 (GTX 1080 or older), only .CB is not None and .CxB is None.

You can create a test case for this with a >= 7.5 GPU by using the state.force_no_igemmlt variable. See this Linear8bitLt test case for an example. Could you check your implementation for this case, make the necessary changes and report back? Thank you!
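
For illustration, such a parametrized test might look roughly like this (an assumption-laden sketch, not the test that was eventually merged; the layer size and input shape are arbitrary):

```python
import pytest
import torch
import bitsandbytes as bnb

@pytest.mark.parametrize("force_no_igemmlt", [False, True])
def test_linear8bit_serialization(force_no_igemmlt):
    def make_layer():
        layer = bnb.nn.Linear8bitLt(32, 32, has_fp16_weights=False)
        if force_no_igemmlt:
            # Emulates a GPU with compute capability < 7.5:
            # only .CB gets populated and .CxB stays None.
            layer.state.force_no_igemmlt = True
        return layer.cuda()

    original = make_layer()
    x = torch.randn(4, 32, dtype=torch.float16, device="cuda")
    y_ref = original(x)  # output before serialization

    restored = make_layer()
    restored.load_state_dict(original.state_dict())
    assert torch.allclose(y_ref, restored(x))
```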

@TimDettmers (Collaborator) commented:
Thank you, Max, for this PR. If the test passes, this looks good to me. I will merge and push out a release soon.

@TimDettmers TimDettmers merged commit ed6f3eb into main Apr 11, 2023
@mryab (Collaborator, Author) commented Apr 11, 2023:

Hi @TimDettmers, thanks for the review! I've implemented the force_no_igemmlt flag as you requested and ran the tests on a 2080 machine: some combinations of test cases fail, mainly when force_no_igemmlt differs before and after serialization:

================================================================================================== short test summary info ===================================================================================================
FAILED tests/test_linear8bitlt.py::test_linear_serialization[False-False-False-False-True] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[False-False-False-True-False] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[False-True-False-False-True] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[False-True-False-True-False] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-False-False-False-True] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-False-False-True-False] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-False-True-False-True] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-False-True-True-False] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-True-False-False-True] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-True-False-True-False] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-True-True-False-True] - AssertionError: assert False
FAILED tests/test_linear8bitlt.py::test_linear_serialization[True-True-True-True-False] - AssertionError: assert False
=============================================================================================== 12 failed, 22 passed in 6.48s ================================================================================================

As such, I think that users with older GPUs might face some inconsistencies in behavior. I was thinking of adding an extra check during serialization along the lines of what you described; alternatively, we can declare the feature a prototype and release a follow-up fix if necessary.
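
To make the idea concrete, a hypothetical sketch of such a serialization-time check (illustrative only; the stand-in class and the error message are assumptions, not code from this PR):

```python
import torch

class _StateSketch:
    # Hypothetical stand-in for bitsandbytes' MatmulLtState (illustration only).
    CB = None              # row-major int8 weights (portable across GPUs)
    CxB = torch.empty(0)   # hardware-specific tiled layout (col_turing/col_ampere)

def check_serializable(state):
    # If the weights exist only in a hardware-specific layout, refuse to
    # serialize rather than silently saving a state dict that would only
    # load correctly on matching hardware.
    if state.CxB is not None and state.CB is None:
        raise NotImplementedError(
            "8-bit weights are stored only in a hardware-specific layout; "
            "serializing them would not be portable across GPUs."
        )

check_serializable(_StateSketch())  # raises NotImplementedError
```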

@TimDettmers TimDettmers deleted the serialize_8bit branch August 5, 2023 23:15