Refactor custom FPx cast #363
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/363
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit bd64efc with merge base 664f073.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
2x is a sizeable regression. How about keeping the LUT for the formats we already have it for and adding a generic fallback for the other formats? People can then optimize format by format if they want.
@vkuzo I have updated the dequant denormal implementation. There is no speed regression anymore (I updated the results in the first post), and I didn't need the hard-coded LUT from your implementation. If the torch compiler does constant folding and loop unrolling properly, I think my implementation should match your previous one exactly. If possible, please benchmark on your GPUs to confirm there is no regression.
Here are results on an H100: https://gist.github.com/vkuzo/324256b8defd0231852a23cbb34f49a6
I see no meaningful change in performance, awesome stuff.
def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """
    TODO(future): check if LUT for everything is faster than bit shifting,
Is this comment still relevant?
Maybe add a docblock?
Using a LUT for everything in dequant might be faster, like the current NF4 implementation. I haven't benchmarked it, so I'm not sure.
I didn't add a docblock here since I consider this an internal function, but a simple doc won't hurt. I will add some docs for this and the quant function above. I already added a short description of these two functions at the top of the file.
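For readers following the "bit shifting vs. LUT" question in this thread, here is a minimal pure-Python sketch of what a shift-based decode of a single unpacked FPx value could look like. It is not the code from this PR: the name `fpx_bits_to_float` is hypothetical, it works on one integer rather than a Tensor, and it assumes the format has no inf/nan encodings and a bias of 2^(ebits-1) - 1.

```python
import struct


def fpx_bits_to_float(bits: int, ebits: int, mbits: int) -> float:
    """Decode one unpacked FPx pattern laid out as sign | exponent | mantissa.

    Hypothetical helper for illustration only. Assumes no inf/nan encodings.
    """
    sign = (bits >> (ebits + mbits)) & 1
    exp = (bits >> mbits) & ((1 << ebits) - 1)
    man = bits & ((1 << mbits) - 1)
    bias = (1 << (ebits - 1)) - 1  # assumed bias: 2^(ebits-1) - 1

    if exp == 0:
        # Denormal (or zero): no implicit leading one, fixed exponent 1 - bias.
        mag = man * 2.0 ** (1 - bias - mbits)
        return -mag if sign else mag

    # Normal: rebias the exponent to float32's bias of 127 and left-align
    # the mantissa inside float32's 23-bit mantissa field.
    f32_bits = (sign << 31) | ((exp - bias + 127) << 23) | (man << (23 - mbits))
    return struct.unpack("<f", struct.pack("<I", f32_bits))[0]
```

In this sketch the normal path is pure shifts and ORs, while the denormal path is the only part that needs special handling, which is the part a LUT could replace.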
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)

def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
Summary:
This diff
- refactors install_et.sh into a bunch of utils
- Uses those utils in build_android.sh to minimize duplication.
- Makes sure that we are building with the custom sdpa op
Test Plan:
Model export
python export.py --quant '{"linear:a8w4dq" : {"groupsize": 256}}'
--checkpoint-path /home/kimishpatel/models/llama2/stories/stories110M.pt
--params-path /home/kimishpatel/models/llama2/stories/params.json
--output-pte-path /tmp/stories110m_a8w4dq.pte
python utils/tokenizer.py --tokenizer-model=/tmp/tokenizer.model
linux:
./scripts/install_et.sh
rm -rf build/cmake-out/
cmake -S ./runner-et -B build/cmake-out -G Ninja
cmake --build ./build/cmake-out
./build/cmake-out/runner_et /tmp/stories110m_a8w4dq.pte -z /tmp/tokenizer.bin -t 0 -n 120
android:
./runner-et/build_android.sh
adb push ./build/cmake-out-android/runner_et /data/local/tmp/
adb push /tmp/stories110m_a8w4dq.pte /data/local/tmp/
adb push /tmp/tokenizer.bin /data/local/tmp/
adb shell "cd /data/local/tmp && ./runner_et ./stories110m_a8w4dq.pte -z
./tokenizer.bin -t 0 -n 120"
Will add build commands to ci in the next PR
Reviewers:
Subscribers:
Tasks:
Tags:
Closes #354
TODO:
Check torch.compile
Benchmark before and after
8841094 (main)
2690b92 (this PR)
Dequant is 2x slower because I replaced the LUT-based denormal handling with more generic logic. @vkuzo Should I add back the LUT-based logic (checking specifically for E2M3, E3M2, and E2M1)? If we care about performance, perhaps we can generate a LUT for all bit patterns and cache it.
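A rough sketch of the "generate a LUT for all bit patterns and cache it" idea, in plain Python so it runs without torch. The names `_decode`, `fpx_lut`, and `dequant` are hypothetical and not part of this PR; the sketch assumes a sign | exponent | mantissa layout, a bias of 2^(ebits-1) - 1, and no inf/nan encodings. In a torch version the table would presumably be a small tensor and the lookup an indexing op.

```python
from functools import lru_cache


def _decode(bits: int, ebits: int, mbits: int) -> float:
    """Arithmetic decode of one FPx pattern (hypothetical reference helper)."""
    sign = -1.0 if (bits >> (ebits + mbits)) & 1 else 1.0
    exp = (bits >> mbits) & ((1 << ebits) - 1)
    man = bits & ((1 << mbits) - 1)
    bias = (1 << (ebits - 1)) - 1  # assumed bias
    if exp == 0:
        # Denormal: exponent fixed at 1 - bias, no implicit leading one.
        return sign * man * 2.0 ** (1 - bias - mbits)
    return sign * (1.0 + man / (1 << mbits)) * 2.0 ** (exp - bias)


@lru_cache(maxsize=None)
def fpx_lut(ebits: int, mbits: int) -> tuple:
    """Build the table once per (ebits, mbits); later calls hit the cache."""
    n_patterns = 1 << (1 + ebits + mbits)  # sign bit + exponent + mantissa
    return tuple(_decode(b, ebits, mbits) for b in range(n_patterns))


def dequant(codes, ebits: int, mbits: int) -> list:
    """Dequantize unpacked codes with a single table lookup per element."""
    lut = fpx_lut(ebits, mbits)
    return [lut[c] for c in codes]
```

For the formats discussed here the tables are tiny (16 entries for E2M1, 64 for E2M3 and E3M2), so the cache cost is negligible; whether a lookup actually beats the shift-based path on GPU would still need benchmarking.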
UPDATE
95f4582 (this PR v2)
Now FP4_E2M1 is slower lol. I feel like this should be bandwidth-limited, though it might also be register-limited. I will do some profiling and make sure torch.compile runs optimally. Interesting that native PyTorch float8 dequant is slower.
UPDATE 2
dcd5a05 (this PR v3)
Speed recovered 😊