Improve FSDP support for low-bit optimizers#538
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/538
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 6cec214 with merge base e8662e0. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Before:

```python
@property
def block_size(self):
    return self.codes.numel() * 2 // self.scale.numel()
```

After:

```python
self.block_size = codes.numel() * 2 // scale.numel()
```
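The change can be illustrated with a minimal sketch (the class name and tensor sizes below are illustrative, not the actual torchao code): `block_size` is computed once in `__init__` instead of being re-derived on every access through a `@property`.

```python
import torch

# Hypothetical minimal sketch: compute block_size once in __init__
# rather than deriving it on every access via a @property.
class QuantState4bit:
    def __init__(self, codes: torch.Tensor, scale: torch.Tensor):
        self.codes = codes  # bit-packed: two 4-bit values per uint8 element
        self.scale = scale  # one scale per block (1D)
        # each packed byte holds 2 elements, hence the factor of 2
        self.block_size = codes.numel() * 2 // scale.numel()

state = QuantState4bit(codes=torch.zeros(1024, dtype=torch.uint8),
                       scale=torch.ones(16))
print(state.block_size)  # 1024 * 2 // 16 = 128
```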
Curious q: is there some description of the codes/scales tensors and their relation to each other? I can see the pattern that codes has 0.5x (4-bit) and 1x (8-bit) the block_size * scale numels, but does this assert square blocks? I think some description here would be helpful.
I will add some description. Basically, for 8-bit and FP8, codes has the same shape as the "outer shape", while for 4-bit, since there is bit-packing, I find it easier to let codes be a flattened 1D buffer and keep track of the shape manually.
To get the scale, the float tensor is actually flattened first and reshaped to (-1, block_size). This is done to relax the requirement that the last dimension must be divisible by block_size -> now we only need numel (total size) to be divisible by block_size. This is especially needed when the block size is large (8-bit optim uses block_size=2048, as done in bnb). Since the optim update is element-wise, we don't really need to care whether the original tensor is 1D, 2D, or n-D (well, maybe there is some structure in an n-D tensor that makes flattening it unwise). I believe the original implementation in bnb does this as well.
-> scale is always a 1D tensor, with size = original_tensor.numel() // block_size
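The flatten-then-reshape trick described above can be sketched as follows (the function name and the absmax scale choice are illustrative assumptions; the actual torchao kernels may differ):

```python
import torch

# Sketch of blockwise scale computation: flatten first, then view as
# (-1, block_size), so only numel (not the last dim) must be divisible
# by block_size. block_size=2048 follows the bnb 8-bit optimizer default.
def blockwise_absmax_scale(x: torch.Tensor, block_size: int = 2048) -> torch.Tensor:
    assert x.numel() % block_size == 0
    blocks = x.flatten().view(-1, block_size)  # (numel // block_size, block_size)
    scale = blocks.abs().amax(dim=1)           # one scale per block -> always 1D
    return scale

x = torch.randn(64, 64)  # numel = 4096; last dim (64) < block_size is fine
scale = blockwise_absmax_scale(x, block_size=2048)
print(scale.shape)  # torch.Size([2]) == numel // block_size
```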
@drisspg Added some docs. Lmk if it is still unclear.
I don't use FSDP at work or in personal projects because I don't have access to multi-GPU machines, so I can't really answer your question 😅. I only added FSDP support for low-bit optimizers due to requests from people. At least in torchtune, I saw that
```python
def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
    """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507.
```
btw the link is not valid, can you remove the `.` at the end?
```python
def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
    """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861.
```
msaroufim left a comment:
Nice! We should look to get the low-bit optimizers out of prototype soon!
- Use `DTensor.from_local(run_check=False)` to wrap quantized optim state (instead of swapping `_local_tensor`)
- Make `block_size` a fixed attribute, calculated inside `__init__` (instead of dynamically calculated every time)
- Implement `all_gather_into_tensor` and `wait_tensor` to support `DTensor.full_tensor()`