Add TTFT benchmarks + update sparsity benchmarks #1140
Summary: This PR adds a sparsity option to the LLaMa benchmarks.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1140
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit de2d447 with merge base 2f97b09:
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
from torchao.dtypes import MarlinSparseLayout
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
if sparsity and "semi" in sparsity:
    quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
This isn't using any of the derived variables. It should either use the derived ones or live in a separate section.
HDCharles left a comment
LGTM if you move the Marlin code so it's clearer which derived variables it actually uses.
if write_result:
    result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
    result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
    result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, time={t:5.4f} sec, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
"time" is a really generic term. Is this TTFT or the overall run? The tok/s figure already covers the non-prefill part, so TTFT or prefill time is probably the more valuable number to report.
It's overall time, but I limit num_tokens to 1. I can make this a bit clearer though, maybe with a --ttft flag that forces num_tokens_generated to be 1.
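As a rough illustration of the flag proposed above, here is a minimal sketch of a `--ttft` option that overrides the token count. The parser, the `--num_tokens` option name, and its default of 200 are assumptions made for this example; they are not taken from the actual benchmark script.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical CLI sketch: --ttft forces generation of a single token."""
    parser = argparse.ArgumentParser(description="Benchmark CLI sketch (illustrative only)")
    # Assumed option name and default; the real script may differ.
    parser.add_argument("--num_tokens", type=int, default=200,
                        help="Number of tokens to generate per sample")
    parser.add_argument("--ttft", action="store_true",
                        help="Measure time to first token by generating exactly one token")
    args = parser.parse_args(argv)
    if args.ttft:
        # With one generated token, the overall run time is prefill plus one
        # decode step, which is what the reported time is meant to capture.
        args.num_tokens = 1
    return args
```

Calling `parse_args(["--ttft", "--num_tokens", "50"])` would still yield `num_tokens == 1`, so the flag wins over an explicit token count. Whether that precedence is desirable, or whether the combination should raise an error, is a design choice left open here.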
6858180 to 4fdfa7b
This PR adds TTFT (time to first token) benchmarks to torchAO and updates the benchmarking script to handle sparsity more cleanly and to use the available 2:4 sparse checkpoints. It also adds padding support for int8 dynamic quantization + 2:4 sparsity, which was previously missing.
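To make the distinction between TTFT and overall run time concrete, here is a minimal sketch that times any token-yielding generator. The `measure_ttft` helper and its `generate` argument are hypothetical names for this example and are not part of torchAO. On a GPU, the device would also need to be synchronized before each clock read, which this CPU-only sketch omits.

```python
import time
from typing import Callable, Iterable, Tuple


def measure_ttft(generate: Callable[[], Iterable[int]]) -> Tuple[float, float, int]:
    """Return (time_to_first_token, total_time, num_tokens) in seconds.

    `generate` is a hypothetical zero-argument callable that yields tokens
    one at a time; it stands in for a model's generation loop.
    """
    start = time.perf_counter()
    ttft = None
    num_tokens = 0
    for _ in generate():
        num_tokens += 1
        if ttft is None:
            # The first yielded token marks the end of prefill plus one decode step.
            ttft = time.perf_counter() - start
    total = time.perf_counter() - start
    if ttft is None:
        raise ValueError("generator produced no tokens")
    return ttft, total, num_tokens
```

When the generator yields exactly one token, TTFT and total time are nearly identical, which is why limiting generation to a single token lets the overall time serve as a TTFT measurement.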
Hi @vkuzo, thanks for the great work!
* Torchchat CLI pipeline for Multimodal Models
* Remove torchaudio check; we don't use it
* Flip the imports back for ET
---------
Co-authored-by: vmpuri <puri@meta.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
It's available as of last night!
Thanks!