adding whole Linear8bitLt/Linear4bit module save/load serialization #1099
Titus-von-Koeller merged 1 commit into bitsandbytes-foundation:main from
Conversation
cc @Titus-von-Koeller wdyt? Might be good to have for the next release, no?
Yes, I agree, this is looking good and should be merged before the release. I'll review it in more depth soon. Thanks @rdyro for the good work and for taking the initiative to contribute, really appreciated 🤗
Thanks for the positive feedback! I really like your work with the library. Let me know if, ideally, the new tests should extend to all Linear layers, not just the two covered here.
Dear @rdyro, I just reviewed your proposed changes and everything really looks good! I don't think any additional tests are needed; what you did already looks good the way it is. I also ran the tests. Thanks so much for your contribution, and if you feel like contributing more, we'd be happy to support you!
The purpose of this pull request is to allow `torch.save`/`torch.load` directly on modules containing `Linear4bit` and `Linear8bitLt` submodules.

Currently, `torch.save` followed by `torch.load` on `Linear8bitLt` (after the first forward pass) causes a missing-field `CB` error in the `Int8Params` class. This PR makes torch aware of the `CB` and `SCB` fields in the `Int8Params` class. The core of this PR is in the `Int8Params` class.

I also added a `torch.save` -> `torch.load` test for both `Linear4bit` (this was already working) and `Linear8bitLt` (this was not yet working).

While saving modules directly in PyTorch with `save` and `load` is not good practice, the change needed to make this work is minimal, and it makes disk-caching modules during development easier.
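The general pattern behind a fix like this can be sketched without `bitsandbytes` or a GPU. The sketch below is a hypothetical stand-in, not the actual `Int8Params` implementation: `TensorLike` mimics a base class whose serialization only captures its own payload (the way a `torch.Tensor` subclass can silently lose extra attributes), and the subclass overrides `__getstate__` so the `CB` and `SCB` fields survive a save/load round trip. All class names and values here are illustrative assumptions.

```python
import pickle


class TensorLike:
    """Stand-in base class that, like a tensor, serializes only its own
    payload and silently drops attributes added by subclasses."""

    def __init__(self, data):
        self.data = data

    def __getstate__(self):
        # Only the base payload is captured; subclass fields are lost here.
        return {"data": self.data}

    def __setstate__(self, state):
        self.__dict__.update(state)


class Int8ParamsLike(TensorLike):
    """Hypothetical parameter class carrying quantization state."""

    def __init__(self, data, CB=None, SCB=None):
        super().__init__(data)
        self.CB = CB    # quantized weights, populated after the first forward
        self.SCB = SCB  # row-wise scaling constants

    def __getstate__(self):
        # The fix: make serialization aware of CB and SCB as well,
        # so loading does not hit a missing-field error.
        state = super().__getstate__()
        state["CB"] = self.CB
        state["SCB"] = self.SCB
        return state


p = Int8ParamsLike([1.0, 2.0], CB=[12, 25], SCB=[0.08])
restored = pickle.loads(pickle.dumps(p))
assert restored.CB == [12, 25] and restored.SCB == [0.08]
```

Since `torch.save`/`torch.load` ultimately go through pickle, teaching the parameter class to include these fields in its state is what lets a whole module containing such submodules round-trip through disk.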