Improve QAT nvfp4 numerics#3050
Conversation
**Summary:** This commit adds an option for the existing `NVFP4InferenceConfig` to dynamically compute an appropriate fp32 per tensor scale to support the two level scaling according to the NVFP4 specification: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/. While two level scaling is supported in `NVFP4Tensor`, today there is no config API for users to call this. The existing `NVFP4InferenceConfig` only supports single level scaling because including an explicit `per_tensor_scale` field would make serialization tricky. In the future, we should add an end-to-end calibration flow so users can compute an appropriate per tensor scale for the activations first, and then pass this to `NVFP4Tensor` as a static scale, similar to the proposal in #2572. **Test Plan:** ``` pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4 pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` Also did a quick benchmark before and after: ``` import copy import time import torch from torchao.quantization import quantize_ from torchao.prototype.mx_formats import NVFP4InferenceConfig m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda") m_mx2 = copy.deepcopy(m_mx1) config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False) config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True) quantize_(m_mx1, config=config1) quantize_(m_mx2, config=config2) m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager") m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager") start = time.time() for _ in range(1000): m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)) print("No per_tensor_scale = ", time.time() - start, "seconds") start = time.time() for _ in range(1000): m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)) print("With per_tensor_scale = ", time.time() - start, "seconds") ``` On a single B200: ``` No per_tensor_scale = 1.2855589389801025 seconds With per_tensor_scale = 1.3009123802185059 seconds ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. Details TBD. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3050
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit fa4d9ee with merge base 5cbbd73 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. Details TBD. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: 04f6bce Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: 04f6bce Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: 04f6bce Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: 47019f4 Pull Request resolved: #3050
| self._test_quantize_api_against_ptq( | ||
| NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale), | ||
| target_prepare_sqnr=12, | ||
| target_prepare_sqnr=target_prepare_sqnr, |
There was a problem hiding this comment.
can you explain this I would expect that per-tensor scaling would have less error and thus a higher sqnr? but also I dont know what prepare sqnr means?
There was a problem hiding this comment.
Update: they're both inf now. The per-tensor case wasn't inf before because we weren't simulating _addmm_nvfp4_dispatch in QAT, but instead just called F.linear on NVFP4Tensor.to_nvfp4().to_dtype(), which doesn't give the same numerics
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: d8f7eff Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: d0120f0 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance: 1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked` 2. Do not cast intermediate fake quantized values to original dtype, e.g. bf16 which loses some fidelity from fp32 3. Fake round blockwise scales to float8 **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: d0120f0 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimick the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: ecbff90 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimick the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: ecbff90 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimick the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimick the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimick the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` End-to-end tests TBD. ghstack-source-id: a707a59 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: a707a59 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: 633bc65 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: 633bc65 Pull Request resolved: #3050
| @@ -444,12 +444,16 @@ def _infer_fake_quantize_configs( | |||
| elif isinstance(base_config, NVFP4InferenceConfig): | |||
| if NVFP4MMConfig.DYNAMIC: | |||
There was a problem hiding this comment.
nit: if this is a boolean, might be good to say NVFP4MMConfig.is_dynamic I think, although probably not relevant to this PR
There was a problem hiding this comment.
ok, can fix separately since this is a PTQ config
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: 633bc65 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: 77f47b7 Pull Request resolved: #3050
vkuzo
left a comment
There was a problem hiding this comment.
lg for B200, I think there should be a TODO somewhere to in the future add a path for non-B200s if they want to at least emulate nvfp4 numerics
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: 77f47b7 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: 77f47b7 Pull Request resolved: #3050
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` ghstack-source-id: bb1356c Pull Request resolved: #3050
Stack from ghstack (oldest at bottom):
Summary: Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.
Unit tests:
End-to-end tests:
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
Wikitext:
perplexity than the quantized baseline
same as the quantized baseline