feat: Separate Q and KV dtypes for decode#286
Conversation
|
@yzh119 Please let me know if this is on the right track! I couldn't see anything directly related to the dtype of the query in the kernels, so my assumption is this should "just work", but I don't know if this will not affect eg. |
|
Yes I do think you are on the right track, thank you!
I don't think so. |
|
@yzh119 The modified unit test passes for me, can you review and validate? |
There was a problem hiding this comment.
Hi @Yard1 , thanks so much for doing this and it look good to me in general.
I beg some other changes, mainly around BeginForward functions because it seems you assume we are using the same data type for q and kv and it might affect some resource estimation.
I left some suggested changes, besides them, you also need to separate qtype and kvtype in this function (pass the qtype also as an empty tensor):
flashinfer/python/flashinfer/decode.py
Lines 532 to 620 in 1250b68
and update
flashinfer/python/csrc/flashinfer_ops.h
Lines 77 to 80 in 1250b68
flashinfer/python/csrc/batch_decode.cu
Lines 120 to 188 in 1250b68
accordingly.
|
@yzh119 correct, I wanted to avoid having to modify the public API. I don't think the information about the query dtype will be used in resource estimation, but please correct me if that's not the case - happy to do the change then |
|
Hi @Yard1 , I'm a little bit conservative here because this section of code flashinfer/include/flashinfer/attention/handler.cuh Lines 121 to 130 in 1250b68 might produce different |
|
Ok sounds good! Let me make the change. |
|
@yzh119 Updated, ptal! |
🤖 I have created a release *beep* *boop* --- ## [0.1.0](v0.0.4...v0.1.0) (2024-06-20) ### Highlights * Support any GQA group size support for tensor-cores kernels. * Support any page size support for tensor-cores kernels. * Support CUDA-Graph for prefill/decode APIs. * Add an option to accelerate decode kernels with Tensor Cores. * Support custom attention mask. (https://docs.flashinfer.ai/tutorials/kv_layout.html#mask-layout-2d-ragged-tensor) * Support logits cap in Grok-1 models. * Fused GPU-sampling kernels: top-p, top-k, speculative verification. (https://docs.flashinfer.ai/api/python/sampling.html) * PyTorch wrapper of group-gemm cutlass kernels. (https://docs.flashinfer.ai/api/python/sampling.html) ### Acknowledgement We thank [@ibsidorenko](https://github.com/ibsidorenko), [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU), [@Yard1](https://github.com/Yard1) [@AgrawalAmey](https://github.com/AgrawalAmey), [@xuzhenqi](https://github.com/xuzhenqi), [@mgerstgrasser](https://github.com/mgerstgrasser), [@esmeetu](https://github.com/esmeetu), [@yz-tang](https://github.com/yz-tang), [@HSQ79815](https://github.com/HSQ79815), [@Qubitium](https://github.com/Qubitium), [@shreygupta2809](https://github.com/shreygupta2809), [@sighingnow](https://github.com/sighingnow), [@vinx13](https://github.com/vinx13), [@tqchen](https://github.com/tqchen), [@merrymercy](https://github.com/merrymercy), [@comaniac](https://github.com/comaniac) and many others for their contributions and helpful discussions for 0.0.5 release. ### Refactor * support any GQA group size for tensor-cores kernels ([#301](#301)) ([c111ca](c111ca6)) * support any page size for tensor-cores kernels ([#306](#306)) ([82fd8c](82fd8c7)) ### Features * add `use_tensor_cores` option to decode kernels to accelerate GQA ([#317](#317)) ([3b50dd5](3b50dd5)) * add group gemm operators ([#282](#282)) ([e08ba42](e08ba42)) * initial support of distributed operators ([#289](#289)) ([03553da](03553da)) * initial support of logits hook ([#298](#298)) ([ab1e2ad](ab1e2ad)) * Separate Q and KV dtypes for decode ([#286](#286)) ([5602659](5602659)) * support cuda graph for batched multi-query(prefill/append) attention ([#275](#275)) ([83ceb67](83ceb67)) * support cuda graph for batched multi-query(prefill/append) attention ([#277](#277)) ([24cc583](24cc583)) * support custom attention mask in prefill/append attention kernels ([#266](#266)) ([7304282](7304282)) * fused speculative sampilng kernels ([#259](#259)) ([cea2bb](cea2bb9)) * expose sampling APIs in pytorch ([#238](#238)) ([092902](0929023)) ### Performance Improvements * initial cuda graph support ([#256](#256)) ([7e9cc7f](7e9cc7f)) * split kv-cache for prefill/append kernels ([#310](#310)) ([f0bb0a3](f0bb0a3)) * use packed bit array for attention mask ([#308](#308)) ([3d43dc9](3d43dc9)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Zihao Ye <expye@outlook.com>
Closes #285
Modified unit tests pass. May need some extra validation.