[Feature] Support fp8 e5m2 kv cache with flashinfer #1204

merrymercy merged 7 commits into sgl-project:main
Conversation
Nice work! I'll review it asap. May we also support FP8 E4M3?

FP8 E4M3 needs a scale factor and calibration. We may add it in the future.
```python
if self.server_args.kv_cache_dtype == "auto":
    self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
    if self.server_args.disable_flashinfer or self.server_args.enable_mla:
        ...
```
Currently, only FlashInfer is supported, not Triton, because the Triton kernel has insufficient shared memory (smem). This needs to be fixed in another PR.
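For context, a minimal sketch of what this dtype selection amounts to; the helper name `resolve_kv_cache_dtype` and the raise-on-unsupported fallback are illustrative, not the PR's exact code:

```python
import torch

def resolve_kv_cache_dtype(server_args, model_dtype: torch.dtype) -> torch.dtype:
    """Illustrative helper: pick the KV-cache dtype from server args."""
    if server_args.kv_cache_dtype == "auto":
        # "auto" keeps the KV cache in the model's own dtype (e.g. fp16).
        return model_dtype
    if server_args.kv_cache_dtype == "fp8_e5m2":
        # Only the FlashInfer backend supports fp8 e5m2 here; the Triton
        # kernel currently runs out of shared memory (see comment above).
        if server_args.disable_flashinfer or server_args.enable_mla:
            raise ValueError("fp8_e5m2 kv cache requires flashinfer without MLA")
        return torch.float8_e5m2
    raise ValueError(f"Unsupported kv_cache_dtype: {server_args.kv_cache_dtype}")
```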
```python
if cache_v.dtype != self.dtype:
    cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
    self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
```
Workaround for float8_e5m2: store the cache as torch.uint8, because Tensor index_put is not implemented for torch.float8_e5m2.
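A minimal sketch of that workaround, assuming a per-layer KV buffer; the class name `IllustrativeKVPool` and its method names are placeholders rather than the PR's actual code:

```python
import torch

class IllustrativeKVPool:
    def __init__(self, size, head_num, head_dim, layer_num, dtype, device="cuda"):
        self.dtype = dtype
        # index_put (advanced-indexing assignment) is not implemented for
        # torch.float8_e5m2, so the buffers physically hold uint8 bits and
        # are reinterpreted with .view() on every write/read.
        self.store_dtype = torch.uint8 if dtype == torch.float8_e5m2 else dtype
        shape = (size, head_num, head_dim)
        self.k_buffer = [torch.empty(shape, dtype=self.store_dtype, device=device)
                         for _ in range(layer_num)]
        self.v_buffer = [torch.empty(shape, dtype=self.store_dtype, device=device)
                         for _ in range(layer_num)]

    def set_kv_buffer(self, layer_id, loc, cache_k, cache_v):
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)  # e.g. fp16 -> fp8 e5m2
            cache_v = cache_v.to(self.dtype)
        if self.store_dtype != self.dtype:
            # Bit-reinterpret fp8 as uint8 so the indexed store is supported.
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v

    def get_key_buffer(self, layer_id):
        if self.store_dtype != self.dtype:
            # View back to fp8 before handing the buffer to the attention kernel.
            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]
```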
Sorry to dig this up, but are we suggesting that the fp8 kv cache increased accuracy on both MMLU and GSM8K? Are we sure we don't have those values in the table reversed?
@qeternity In the previous evaluation, I tested gsm8k with 200 questions (the default setting in the benchmark script), so the result may not have been reliable enough. I have now tested all the datasets and updated the results in the table.
Motivation
Support fp8 e5m2 kv cache with flashinfer.
Usage
Add `--kv-cache-dtype fp8_e5m2` to enable this feature. Currently it only works when flashinfer is not disabled.
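For example (the model path is illustrative):

```bash
python -m sglang.launch_server \
  --model-path meta-llama/Llama-2-13b-chat-hf \
  --kv-cache-dtype fp8_e5m2
```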
Performance & Accuracy
Tested with llama2-13b-chat on an A100, throughput increased by 17.8% with no accuracy degradation.
Reproduce
The performance boost is model-dependent: llama3-8b was also tested, but its performance did not improve.
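The exact reproduce commands were not preserved here; a plausible outline is to launch the server with and without the flag and benchmark both (the model path and the `bench_serving` flags below are assumptions, not taken from the PR):

```bash
# Baseline server
python -m sglang.launch_server --model-path meta-llama/Llama-2-13b-chat-hf

# Server with fp8 e5m2 KV cache
python -m sglang.launch_server --model-path meta-llama/Llama-2-13b-chat-hf \
  --kv-cache-dtype fp8_e5m2

# Measure serving throughput against the running server
python -m sglang.bench_serving --backend sglang --num-prompts 1000
```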