
Add UIntxBitPackedTensor, UIntxWeightOnlyConfig, and Int8DynamicActivationUIntxWeightConfig #4082

Merged

jerryzh168 merged 17 commits into main from gh/jerryzh168/48/head on Mar 19, 2026

Conversation

@jerryzh168 (Contributor) commented Mar 13, 2026

Stack from ghstack (oldest at bottom):

Add v2 tensor subclass UIntxBitPackedTensor(TorchAOBaseTensor) using
gemlite bit-packing and Triton GEMM kernels, replacing the old AQT-based
GemliteUIntXWeightOnlyConfig path.

  • UIntxBitPackedTensor: tensor subclass with from_hp(), dequantize(),
    and aten.linear/t/slice dispatch implementations
  • UIntxWeightOnlyConfig: weight-only quantization (4-bit/8-bit)
  • Int8DynamicActivationUIntxWeightConfig: int8 dynamic activation + uintx weight
  • Tests for both configs covering 4-bit, 8-bit, slice, and non-standard shapes

Test Plan:

  • python test/prototype/test_uintx_bit_packed_tensor.py
  • Tests cover UIntxWeightOnlyConfig: 4-bit (group sizes 64/128, packing widths 32/8), 8-bit (per-channel, packing widths 32/8)
  • Tests cover Int8DynamicActivationUIntxWeightConfig: same bit_width/group_size/packing combos
  • Tests cover slice dim0/dim1 for tensor parallelism
  • Tests cover non-standard shapes (1024x1025)
  • Verified backward compat: old GemliteUIntXWeightOnlyConfig still works

Addressing #3891
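As background on what the bit-packing buys, here is a pure-Python sketch of packing eight 4-bit values into each 32-bit word (the helper names are illustrative only; the real packing lives in gemlite's kernels):

```python
def pack_u4_into_u32(values):
    """Pack 4-bit unsigned values (0-15) into 32-bit words, 8 per word."""
    assert all(0 <= v <= 15 for v in values)
    words = []
    for i in range(0, len(values), 8):
        word = 0
        for j, v in enumerate(values[i:i + 8]):
            word |= v << (4 * j)  # each value occupies its own 4-bit lane
        words.append(word)
    return words

def unpack_u32_to_u4(words, count):
    """Inverse of pack_u4_into_u32: recover `count` 4-bit values."""
    out = []
    for word in words:
        for j in range(8):
            out.append((word >> (4 * j)) & 0xF)
    return out[:count]

vals = [3, 15, 0, 7, 9, 1, 12, 4, 5]
packed = pack_u4_into_u32(vals)  # 9 values -> 2 words instead of 9 bytes+
assert unpack_u32_to_u4(packed, len(vals)) == vals
```

The same idea generalizes to the packing bit widths of 8 and 16 exercised in the tests; only the lane count per word changes.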

pytorch-bot (Bot) commented Mar 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4082

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 62b4c14 with merge base 6e5ea54:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Mar 13, 2026
ghstack-source-id: bbbe034
Pull Request resolved: #4082
meta-cla (Bot) added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Mar 13, 2026
jerryzh168 added a commit that referenced this pull request Mar 13, 2026
ghstack-source-id: 858bd16
jerryzh168 added the module: not user facing label (keeps this PR out of release notes) Mar 13, 2026
jerryzh168 added a commit that referenced this pull request Mar 14, 2026
ghstack-source-id: de8cbe1
jerryzh168 added a commit that referenced this pull request Mar 16, 2026
ghstack-source-id: 7b8c09c
jerryzh168 added a commit that referenced this pull request Mar 16, 2026
ghstack-source-id: 01018a9
@jerryzh168 jerryzh168 requested review from jcaip and vkuzo March 16, 2026 20:31
jerryzh168 added a commit that referenced this pull request Mar 16, 2026
ghstack-source-id: e0729aa
jerryzh168 added a commit that referenced this pull request Mar 16, 2026
ghstack-source-id: cef13ad
@jerryzh168 jerryzh168 requested a review from jcaip March 17, 2026 05:34
vkuzo (Contributor) commented Mar 17, 2026

@claude review this PR

claude (Bot) commented Mar 17, 2026

Claude finished @vkuzo's task in 6m 21s.


  • Read all changed files
  • Analyze existing patterns and cross-references
  • Review test coverage
  • Provide detailed review feedback

Review Summary

This PR adds a v2 tensor subclass UIntxBitPackedTensor(TorchAOBaseTensor) using gemlite bit-packing and Triton GEMM kernels, along with two new quantization configs (UIntxWeightOnlyConfig and Int8DynamicActivationUIntxWeightConfig). The implementation is well-structured and follows established patterns in the codebase. Below are my findings.


Issues

1. gemlite_kwargs["meta_args"] serialization concern (uintx_bit_packed_tensor.py:174-182)

The meta_args value from gemlite_linear.get_meta_args() is stored inside a dict in tensor_attribute_names. If meta_args contains torch.Tensor objects, they won't survive __tensor_flatten__ / __tensor_unflatten__ correctly (the TensorSubclassAttributeJSONEncoder in safetensors can't handle tensors nested in dicts). The old GemliteAQTTensorImpl had the same pattern, so this isn't a regression — but it's worth confirming meta_args is purely scalar/tuple metadata to ensure torch.save/torch.load and safetensors work correctly. Consider adding a serialization round-trip test.

2. gemlite_kwargs is a shallow copy in slice (uintx_bit_packed_tensor.py:347)

gemlite_kwargs = self.gemlite_kwargs.copy()

This is a shallow copy. Mutable nested objects (like meta_args if it's a list/dict) will be shared between the original and sliced tensor. If meta_args is immutable (tuple of scalars), this is fine. If it's a list or dict, mutations to meta_args on the sliced tensor would affect the original. Consider using copy.deepcopy if meta_args could be mutable, or document the assumption.
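The hazard described here can be reproduced in isolation with plain dicts standing in for gemlite_kwargs (nothing below is torchao or gemlite API):

```python
import copy

# Stand-in for the tensor's gemlite_kwargs; meta_args is a mutable list.
gemlite_kwargs = {"in_features": 1024, "meta_args": [64, 32]}

shallow = gemlite_kwargs.copy()
shallow["in_features"] = 512   # top-level key: rebinding is independent
shallow["meta_args"][0] = 999  # nested list: shared with the original!
assert gemlite_kwargs["in_features"] == 1024
assert gemlite_kwargs["meta_args"][0] == 999  # original mutated through the copy

deep = copy.deepcopy(gemlite_kwargs)
deep["meta_args"][0] = 64      # deep copy: original is untouched
assert gemlite_kwargs["meta_args"][0] == 999
```

If meta_args is a tuple of scalars, dict.copy() is safe; the example shows exactly what goes wrong when it is not.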

3. except Exception for gemlite import (uintx_bit_packed_tensor.py:16)

except Exception:
    gemlite = None

Using except Exception is overly broad — it would silently swallow real errors like SyntaxError in a partially installed package or CUDA initialization failures. The old code in gemlite_layout.py used bare except: (even worse), and the test file uses except ModuleNotFoundError (better). Consider using except (ImportError, ModuleNotFoundError): here for consistency with the test and to avoid masking real errors.

4. Hard .cuda() calls in from_hp and dequantize (uintx_bit_packed_tensor.py:127, 253)

int_data = int_data.cuda()  # line 127
self.packed_weight.cuda()   # line 253

These force data onto cuda:0. If the user is working on a multi-GPU setup (e.g., cuda:1), this could move data to the wrong device. The from_hp method already has device = hp_tensor.device on line 118 — consider using .to(device) instead of .cuda() (or at minimum .cuda(device)) to preserve the original device index. Similarly in dequantize, use self.packed_weight.device to target the correct GPU.


Suggestions (non-blocking)

5. No serialization or torch.compile tests

The test file covers quantization, forward pass, SQNR, slicing, FQN support, and non-standard shapes — which is good. However, there are no tests for:

  • torch.save / torch.load round-trip
  • Safetensors serialization
  • torch.compile compatibility

These are important for production use. Other tensor subclasses in the codebase are typically tested with these. Consider adding them in a follow-up.

6. UIntxWeightOnlyConfig / Int8DynamicActivationUIntxWeightConfig not exported from __init__.py

The old GemliteUIntXWeightOnlyConfig is exported from torchao/quantization/__init__.py. The new configs are only importable via torchao.prototype.quantization.quant_api. If the intent is for these to be user-facing (even as prototype), consider adding them to the prototype package's __init__.py or documenting the import path.

7. Duplicated validation logic between configs and from_hp

Both UIntxWeightOnlyConfig.__post_init__ (line 140) and UIntxBitPackedTensor.from_hp (line 105) validate bit_width in [4, 8]. The from_hp method also validates group_size, packing_bitwidth, and dtype — but the configs don't. Consider moving all validation to __post_init__ so users get errors at config construction time rather than at quantization time. For example, the constraint that group_size must be None for bit_width=8 (line 110) could be enforced in the config.
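The kind of construction-time validation suggested here can be sketched with a dataclass. Field names mirror the PR's config, but the class below is an illustrative stand-in, not the torchao implementation:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class UIntxWeightOnlyConfigSketch:
    bit_width: int = 4
    group_size: Optional[int] = 64
    packing_bitwidth: int = 32

    def __post_init__(self):
        # All constraints checked at config construction, not at quantize time.
        if self.bit_width not in (4, 8):
            raise ValueError(f"bit_width must be 4 or 8, got {self.bit_width}")
        if self.bit_width == 8 and self.group_size is not None:
            raise ValueError("group_size must be None for bit_width=8 (per-channel)")
        if self.packing_bitwidth not in (8, 16, 32):
            raise ValueError(f"unsupported packing_bitwidth {self.packing_bitwidth}")

UIntxWeightOnlyConfigSketch(bit_width=4, group_size=128)       # accepted
try:
    UIntxWeightOnlyConfigSketch(bit_width=8, group_size=64)    # rejected early
except ValueError as e:
    print(e)
```

Failing in `__post_init__` surfaces the bad combination the moment the config is built, rather than deep inside `from_hp` during model conversion.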

8. uintx/__init__.py is empty besides the license header (uintx/__init__.py)

The __init__.py doesn't export anything. Consider adding imports of UIntxBitPackedTensor so that from torchao.prototype.quantization.uintx import UIntxBitPackedTensor works cleanly.

9. dequantize method correctness for 8-bit symmetric (uintx_bit_packed_tensor.py:248-295)

The dequantize method uses gemlite.bitpack.unpack_over_rows with dtype=torch.uint8. For 8-bit symmetric quantization (signed int8, range -128 to 127), unpacking as uint8 could produce incorrect values since the bit pattern interpretation differs. Verify this path is correct for signed 8-bit data, or add a test that calls dequantize() on an 8-bit quantized tensor and checks the SQNR.
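The bit-pattern concern above can be shown with the struct module: the same stored bytes decode to different values depending on whether they are read as unsigned or signed 8-bit integers.

```python
import struct

raw = bytes([0xFF, 0x80, 0x7F])  # three stored weight bytes

as_uint8 = struct.unpack("3B", raw)  # unsigned interpretation
as_int8 = struct.unpack("3b", raw)   # signed (two's complement) interpretation

# A symmetric-int8 value of -1 stored as byte 0xFF reads back as 255 if
# unpacked as uint8; dequantizing with 255 instead of -1 is badly wrong.
assert as_uint8 == (255, 128, 127)
assert as_int8 == (-1, -128, 127)
```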

10. Slice doesn't update meta_args (uintx_bit_packed_tensor.py:365-366)

After slicing, in_features and out_features are updated in gemlite_kwargs, but meta_args is left unchanged from the original tensor. If meta_args encodes shape-dependent information (like tiling parameters), the sliced tensor's forward_functional call might behave incorrectly. The tests verify that slicing produces correct internal tensor shapes but don't test that a sliced tensor can actually run a forward pass. Consider adding a test that does model.weight = sliced_weight; model(x) to verify end-to-end correctness after slicing.


What looks good

  • The TorchAOBaseTensor subclass pattern is followed correctly, including the dtype convention (matching Int8Tensor, Float8Tensor)
  • The parameter_name kwarg in handler functions follows the established pattern for FQN support
  • Test coverage is good for the core functionality (multiple bit widths, group sizes, packing widths, dynamic activation, slicing, FQN, non-standard shapes)
  • The deprecation warning on the old GemliteUIntXWeightOnlyConfig is a nice touch
  • The aten.t.default no-op implementation correctly handles the F.linear decomposition to t + mm
  • Use of return_and_correct_aliasing is consistent with other implementations

jerryzh168 added a commit that referenced this pull request Mar 18, 2026
ghstack-source-id: 90da320
@jerryzh168 jerryzh168 requested a review from vkuzo March 18, 2026 04:25
| mxfp4 | mxfp4 | {class}`~torchao.prototype.mx_formats.MXDynamicActivationMXWeightConfig`(prototype): Applies mxfp8 or mxfp4 dynamic quantization to activations and weights. Requires NVIDIA SM100+ (Blackwell) or AMD MI350+. |
| intx | bf16 | {class}`~torchao.quantization.IntxWeightOnlyConfig`: Applies intx (1-8 bit) weight-only quantization. Supports groupwise and per-channel. Works with Linear and Conv2D. |
| intx | int8 | {class}`~torchao.quantization.Int8DynamicActivationIntxWeightConfig`: Applies int8 dynamic per-token activation and intx (1-8 bit) weight quantization. CPU optimized. |
| uintx (4/8-bit) | bf16 | {class}`~torchao.prototype.quantization.UIntxWeightOnlyConfig`(prototype): Applies 4-bit (asymmetric, grouped) or 8-bit (symmetric, per-channel) weight-only quantization using gemlite (https://github.com/dropbox/gemlite) Triton kernels. Supports packing bit widths 8, 16, 32. Requires CUDA and gemlite. Optimized for A100 and H100 GPUs. |
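For the `Int8DynamicActivationUIntxWeightConfig` flavor listed above, "int8 dynamic per-token" means each row (token) of the activation tensor gets its own symmetric scale computed at runtime from that row's absolute maximum. A hedged, illustrative sketch of that idea in pure Python — not torchao's implementation:

```python
def dynamic_per_token_int8(activations):
    """Symmetric per-token int8 quantization: one runtime scale per row,
    chosen from that row's absolute maximum so codes span -127..127."""
    quantized, scales = [], []
    for row in activations:
        amax = max(abs(v) for v in row) or 1.0   # avoid div-by-zero on all-zero rows
        scale = amax / 127.0
        quantized.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return quantized, scales

acts = [[0.5, -1.0, 0.25], [10.0, 20.0, -5.0]]
q, scales = dynamic_per_token_int8(acts)
# Each row is scaled independently, so the large second row does not
# destroy the resolution of the small first row.
assert len(scales) == 2 and scales[1] > scales[0]
```

Because the scales are computed per forward pass ("dynamic"), no activation calibration step is needed; the weight side still uses the static uintx packing described in the table row.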
this is great, thank you!

@vkuzo vkuzo left a comment
thank you! looks good if CI is green

…onfig, and Int8DynamicActivationUIntxWeightConfig"

[ghstack-poisoned]
…ynamicActivationUIntxWeightConfig"

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request Mar 19, 2026
…ationUIntxWeightConfig

ghstack-source-id: 18a33f3
Pull Request resolved: #4082
…onfig, and Int8DynamicActivationUIntxWeightConfig"

[ghstack-poisoned]
…ynamicActivationUIntxWeightConfig"

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request Mar 19, 2026
…ationUIntxWeightConfig

ghstack-source-id: c8e4e7e
Pull Request resolved: #4082
@jerryzh168 jerryzh168 changed the base branch from gh/jerryzh168/48/base to main March 19, 2026 20:14
@jerryzh168 jerryzh168 merged commit 5ee9094 into main Mar 19, 2026
39 checks passed
@jerryzh168
Contributor Author

Merging for now. We can also add a uintx_packing_format arg to UIntxWeightOnlyConfig later if it's needed; currently it defaults to bit-packed (in gemlite's flavor).
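To illustrate what "bit packed" means here: with bit_width=4 and a 32-bit packing width, eight 4-bit codes are stored in each 32-bit word. A rough sketch of the idea — not gemlite's actual layout, which also reorders elements to suit its Triton GEMM kernels:

```python
def pack_4bit_into_32(codes):
    """Pack eight 4-bit unsigned codes into each 32-bit word,
    lowest nibble first."""
    assert len(codes) % 8 == 0 and all(0 <= c <= 15 for c in codes)
    words = []
    for i in range(0, len(codes), 8):
        word = 0
        for j, c in enumerate(codes[i:i + 8]):
            word |= c << (4 * j)   # nibble j occupies bits 4j..4j+3
        words.append(word)
    return words

def unpack_32_into_4bit(words):
    return [(w >> (4 * j)) & 0xF for w in words for j in range(8)]

codes = list(range(16))            # sixteen 4-bit values -> two 32-bit words
packed = pack_4bit_into_32(codes)
assert len(packed) == 2
assert unpack_32_into_4bit(packed) == codes
```

An 8-bit packing width would instead store two 4-bit codes per byte; the pack32/pack8 test combinations in the Test Plan exercise both storage widths.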


Labels

- CLA Signed
- module: not user facing