
Align Int4Tensor implementation details with the design of Float8Tensor#2687

Merged
jerryzh168 merged 1 commit into main from jerryzh168/stack/16 on Aug 12, 2025

Conversation

Contributor

@jerryzh168 commented Aug 5, 2025

Stacked PRs:


Align Int4Tensor implementation details with the design of Float8Tensor

Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized tensor; data has shape [N, K/2], and scale/zero_point have shape [K/group_size, N]

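The layout arithmetic above can be sketched in plain Python (a hypothetical helper for illustration only, not torchao code):

```python
# Shape bookkeeping for the Int4Tensor layout described above.
def int4_layout_shapes(n, k, group_size):
    """For an [N, K] high-precision weight, return (data_shape, scale_shape)."""
    assert k % 2 == 0, "K must be even to pack two int4 values per byte"
    assert k % group_size == 0, "K must be divisible by group_size"
    data_shape = (n, k // 2)            # two int4 values packed per int8 byte
    scale_shape = (k // group_size, n)  # one scale/zero_point per group, transposed
    return data_shape, scale_shape

print(int4_layout_shapes(256, 128, 64))  # ((256, 64), (2, 256))
```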
Multiple fixes for Int4Tensor to align with the design of Float8Tensor (only calling fbgemm ops)

Note: This is just refactoring Int4Tensor, no BC related changes in this PR

The Int4Tensor path is exposed in version 2 of Int4WeightOnlyConfig (the default version is still 1, which uses the old AQT path)

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:
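For context on what group-wise int4 quantization computes per group, a rough pure-Python sketch of an affine quantize/dequantize round trip (illustrative only; the real path calls fbgemm kernels, and these helper names are hypothetical):

```python
# Illustrative group-wise affine int4 quantization; not the fbgemm implementation.
def quantize_group(vals, qmin=-8, qmax=7):
    """Map one group of floats to int4 values plus (scale, zero_point)."""
    lo, hi = min(vals), max(vals)
    scale = max((hi - lo) / (qmax - qmin), 1e-9)  # avoid div-by-zero for flat groups
    zero_point = round(qmin - lo / scale)
    q = [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in vals]
    return q, scale, zero_point

def dequantize_group(q, scale, zero_point):
    """Recover approximate float values from int4 values."""
    return [(qi - zero_point) * scale for qi in q]

q, s, zp = quantize_group([0.0, 0.5, 1.0, 1.5])
print(dequantize_group(q, s, zp))  # values close to the originals
```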


pytorch-bot Bot commented Aug 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2687

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ba62d8e with merge base c086ade:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Aug 5, 2025
Summary:
Int4Tensor is the non-preshuffled version of int4 quantized Tensor, data is [N, K/2], scale/zero_point has shape: [K/group_size, N]

Multiple fixes for Int4Tensor to align with the design of Float8Tensor (only calling fbgemm ops)
* Added VERSION 2 for Int4WeightOnlyConfig
* Migrated op implementation and tests from #2387

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 1d84542 to 4874773 on August 5, 2025 03:25
@meta-cla Bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Aug 5, 2025
@jerryzh168 added the module: not user facing label (this PR will not show up in release notes) on Aug 5, 2025
@jerryzh168 changed the base branch from jerryzh168/stack/10 to main on August 5, 2025 18:39
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 4874773 to 1beccb0 on August 5, 2025 18:39
jerryzh168 added a commit that referenced this pull request Aug 5, 2025
@jerryzh168 changed the base branch from main to jerryzh168/stack/10 on August 5, 2025 18:39
Comment thread test/quantization/quantize_/workflows/int4/test_int4_tensor.py Outdated
Comment thread test/quantization/quantize_/workflows/int4/test_int4_tensor.py Outdated
Comment thread test/quantization/quantize_/workflows/int4/test_int4_tensor.py
Comment thread torchao/quantization/quantize_/workflows/int4/int4_tensor.py Outdated
res = torch.ops.fbgemm.bf16i4bf16_rowwise(
input_tensor,
weight_tensor._data.contiguous(),
weight_tensor.qdata.contiguous(),
Contributor

is it expected that the tensors are not contiguous? If not, can we assert this instead of calling contiguous?

Contributor Author
@jerryzh168 Aug 5, 2025

the non-contiguity comes from reshape ops like transpose and view, I think, but the kernel needs these to be contiguous. I can try changing these to asserts and doing the contiguous call on the user side to see if it works

Contributor

I would have expected the weights to be stored in a format aligned with what the kernel needs, without any need for just-in-time layout transforms. Does this match how the current code works?

Contributor Author
@jerryzh168 Aug 5, 2025

normally it is, but the weights also go through some transformations, like the ones we listed in test_moe_weight_reshape_ops, which make the weight / scale etc. non-contiguous, I think. But I can try calling contiguous in user code; that might be cleaner

Contributor Author

turns out contiguous was not implemented properly; just fixed that, and we can remove the contiguous calls in linear/bmm now

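The contiguity issue in this thread comes down to strides: views like transpose change strides without moving data. A minimal stride model in plain Python (no torch; helper names are hypothetical) shows why:

```python
# A tensor is C-contiguous when its strides equal the row-major strides
# implied by its shape; transpose swaps strides, breaking that property.
def row_major_strides(shape):
    """Compute row-major (C order) strides in elements for a given shape."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.insert(0, acc)
        acc *= dim
    return strides

def is_contiguous(shape, strides):
    return strides == row_major_strides(shape)

shape, strides = (2, 3), row_major_strides((2, 3))    # strides [3, 1]
t_shape, t_strides = (3, 2), list(reversed(strides))  # transpose: strides [1, 3]
print(is_contiguous(shape, strides), is_contiguous(t_shape, t_strides))  # True False
```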
Comment thread torchao/quantization/quantize_/workflows/int4/int4_tensor.py
@jerryzh168 changed the base branch from jerryzh168/stack/10 to main on August 5, 2025 23:30
jerryzh168 added a commit that referenced this pull request Aug 5, 2025
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 1beccb0 to 5f6306e on August 5, 2025 23:30
@jerryzh168 changed the base branch from main to jerryzh168/stack/10 on August 5, 2025 23:30


@register_quantize_module_handler(TestOnlyMoEQuantConfig)
def moe_quant_fn(module, config: TestOnlyMoEQuantConfig):
Contributor

this is really confusing, could you share the result of print(model) after this function has been applied?

if it's going to print a model with parameters wrapped in Int4Tensor, can we just wrap the parameters directly, without all of these layers of abstraction?

Contributor

if this is working around the fact that quantize_ needs to work on modules, IMO we should change quantize_ to handle this instead of working around it; seems important for MoEs.

Contributor Author

yeah, the parameters are wrapped in Int4Tensor; this is just applying quantization to each of the MoE weights: w1, w2 and w3

I can inline these for now, and follow up on an API that takes weights and configs separately

Contributor Author

probably not worth changing the API right now since MoE quant is also moving; let me know if the current code looks good

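For readers following along, the handler under discussion roughly does the following. This is a simplified sketch with hypothetical stand-in names (FakeQuantTensor, FakeExperts); the real code wraps each weight in Int4Tensor:

```python
# Simplified sketch of a test-only MoE quant handler: it just swaps each
# expert weight attribute for a "quantized" wrapper, no extra module layers.
class FakeQuantTensor:
    """Stand-in for an Int4Tensor-style quantized weight wrapper."""
    def __init__(self, data):
        self.data = data

class FakeExperts:
    """Stand-in module holding the three MoE expert weights."""
    def __init__(self):
        self.w1, self.w2, self.w3 = [1.0], [2.0], [3.0]

def moe_quant_fn(module, weight_names=("w1", "w2", "w3")):
    # quantize each of the MoE weights w1, w2, w3 in place
    for name in weight_names:
        setattr(module, name, FakeQuantTensor(getattr(module, name)))
    return module

m = moe_quant_fn(FakeExperts())
print(type(m.w1).__name__)  # FakeQuantTensor
```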
Comment thread test/quantization/quantize_/workflows/int4/test_int4_tensor.py
Comment thread torchao/quantization/quantize_/workflows/int4/int4_tensor.py Outdated
Comment thread torchao/testing/model_architectures.py Outdated
return model, input_data


class Experts(nn.Module):
Contributor

maybe call it something like FeedForwardWithExperts? Experts is ambiguous

Contributor Author

this is adapted from https://github.com/meta-llama/llama-models/blob/a9c89c471f793423afd4cc3ca8671d6e56fe64cb/models/llama4/moe.py#L22; how about renaming it to Llama4Experts to make it more specific?

@jerryzh168 changed the base branch from jerryzh168/stack/10 to main on August 6, 2025 01:07
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 5f6306e to 6bd3106 on August 6, 2025 01:08
jerryzh168 added a commit that referenced this pull request Aug 6, 2025
@jerryzh168 force-pushed the jerryzh168/stack/10 branch from 7868bcf to ceac84c on August 7, 2025 02:58
jerryzh168 added a commit that referenced this pull request Aug 7, 2025
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from aee3dbb to 4a50bf7 on August 7, 2025 02:58
@jerryzh168 changed the base branch from jerryzh168/stack/10 to main on August 7, 2025 03:37
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 4a50bf7 to 7a21719 on August 7, 2025 03:37
@jerryzh168 changed the base branch from main to jerryzh168/stack/10 on August 7, 2025 03:37
@jerryzh168 changed the base branch from jerryzh168/stack/10 to main on August 7, 2025 03:51
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 7a21719 to 0040a5f on August 7, 2025 03:51
@jerryzh168 changed the base branch from main to jerryzh168/stack/10 on August 7, 2025 03:51
@jerryzh168 changed the base branch from jerryzh168/stack/10 to main on August 7, 2025 04:29
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 0040a5f to f9695e4 on August 7, 2025 04:29
@jerryzh168 changed the base branch from main to jerryzh168/stack/10 on August 7, 2025 04:29
@jerryzh168 changed the base branch from jerryzh168/stack/10 to main on August 7, 2025 20:56
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from f9695e4 to e84a76d on August 7, 2025 20:56
@jerryzh168 changed the base branch from main to jerryzh168/stack/10 on August 7, 2025 20:56
@jerryzh168 force-pushed the jerryzh168/stack/10 branch from 8fb7215 to 5bb2fd4 on August 7, 2025 23:08
jerryzh168 added a commit that referenced this pull request Aug 7, 2025
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from e84a76d to 8983652 on August 7, 2025 23:08
jerryzh168 added a commit that referenced this pull request Aug 7, 2025
@jerryzh168 force-pushed the jerryzh168/stack/16 branch from 8983652 to abcddcf on August 7, 2025 23:21
Comment thread torchao/quantization/quantize_/workflows/int4/int4_tensor.py
tensor_data_attrs = ["_data", "scale", "zero_point"]
tensor_attributes = ["block_size", "shape"]
tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size", "shape"]
Contributor

tensor_attribute_names is a lil weird to me
are these attributes that are tensors, and thus should go in the right unflatten location?

Contributor Author
@jerryzh168 Aug 11, 2025

yeah.. this somewhat follows what the unflatten and flatten functions name these things: tensor here means the tensor subclass instance, so these are attributes of the tensor subclass instance, not "attributes that are tensors"

I could remove tensor_ as well to make it less confusing? probably better to do in a separate PR

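The name lists in this diff are what let a base class provide generic flatten/unflatten behavior. A simplified sketch of the mechanism (hypothetical names, plain Python; not the actual TorchAOBaseTensor code):

```python
# tensor_data_names lists inner tensors (flattened by name); tensor_attribute_names
# lists plain attributes that travel in the opaque context tuple.
class FakeInt4Tensor:
    tensor_data_names = ["qdata", "scale", "zero_point"]
    tensor_attribute_names = ["block_size", "shape"]

    def __init__(self, qdata, scale, zero_point, block_size, shape):
        self.qdata, self.scale, self.zero_point = qdata, scale, zero_point
        self.block_size, self.shape = block_size, shape

def tensor_flatten(t):
    """Split an instance into named inner tensors plus an attribute context."""
    inner = {name: getattr(t, name) for name in t.tensor_data_names}
    ctx = tuple(getattr(t, name) for name in t.tensor_attribute_names)
    return inner, ctx

def tensor_unflatten(cls, inner, ctx):
    """Rebuild an instance from the named inner tensors and the context."""
    return cls(*(inner[n] for n in cls.tensor_data_names), *ctx)

t = FakeInt4Tensor("q", "s", "z", [1, 64], (256, 128))
t2 = tensor_unflatten(FakeInt4Tensor, *tensor_flatten(t))
print(t2.scale, t2.shape)  # s (256, 128)
```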
Comment thread torchao/quantization/quantize_/workflows/int4/int4_tensor.py Outdated
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4Tensor(TestCase):
class TestInt4Tensor(TorchAOIntegrationTestCase):
Contributor

are we keeping both old and new? If we are keeping the old version working, I would expect this test case to not have any changes, since it would test the old version.

Contributor Author

old version meaning the version using AQT? We are keeping AQT, but this test does not test the AQT path; it only tests the new Int4Tensor, which is what we are updating in this PR

Contributor

can we update the PR summary with context on this? Migrations are always confusing and clearly laying out what is changing with BC, what is breaking BC, and what is not changing will help get a good review.

Contributor Author

OK, added: no BC-related changes in this PR

Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized tensor; data has shape [N, K/2], and scale/zero_point have shape [K/group_size, N]

Multiple fixes for Int4Tensor to align with the design of Float8Tensor (only calling fbgemm ops)
* defined `tensor_data_names` and `tensor_attribute_names` so we can remove some of the implementations from TorchAOBaseTensor
* Migrated op implementation and tests from #2387

Note: This is just refactoring Int4Tensor, no BC related changes in this PR

The Int4Tensor path is exposed in version 2 of `Int4WeightOnlyConfig` (the default version is still 1, which uses the old AQT path)

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16