[WOQ] Add XPU kernel for _weight_int8pack_mm#160938

Closed
xiaowangintel wants to merge 2 commits into pytorch:main from xiaowangintel:xw/int8_woq

Conversation

@xiaowangintel (Contributor) commented Aug 19, 2025

Summary:
This PR adds an XPU kernel for aten._weight_int8pack_mm, a weight-only quantized (WOQ) linear operation that was previously supported only on CPU and CUDA.

Motivation:
Same as #159325.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168

@pytorch-bot bot commented Aug 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160938

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit bcdb3c0 with merge base 90f50f7:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu (CPU specific problem (e.g., perf, algorithm)) and release notes: inductor (aoti) labels Aug 19, 2025
@xiaowangintel xiaowangintel changed the title [WOQ] Add XPU kernel for _weight_int8pack_mm [WIP][WOQ] Add XPU kernel for _weight_int8pack_mm Aug 19, 2025
@github-actions bot commented

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@github-actions bot commented

Attention! One of PyTorch's C-stable API files was changed

You MUST NOT change existing function declarations in this file, as the header defines a stable C ABI. If you need to change a function's signature, introduce a new v2 version of the function and modify code generation to target the new version.


Caused by:

@xiaowangintel (Contributor, Author)

@liangan1 @ZhiweiYan-96 @guangyey @EikanWang Please help review this PR.

@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 19, 2025
Comment on lines +571 to +582
TORCH_CHECK(
A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
__func__,
" : expect A to be either 32-bit or 16-bit float tensor.");
TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
TORCH_CHECK(
A.stride(1) == 1,
__func__,
" : A must be contiguous on the last dimension.");
TORCH_CHECK(B.dtype() == kChar, __func__, " : expect B to be int8 tensor.");
TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
TORCH_CHECK(B.size(1) == K, __func__, " : expect B.size(1) == ", K);
Collaborator:
__func__ is already included by TORCH_CHECK, so __func__ could be removed here.

A.contiguous(),
1.0,
0,
B.contiguous(),
Contributor:

Suggested change
B.contiguous(),
B

// --- Launch kernel ---
Tensor bias = at::Tensor();
Tensor mat2_zero_points = at::Tensor();
Tensor non_const_scales = scales;
Contributor:

Since there is no more operation on non_const_scales, why not use scales directly?

Contributor (Author):

quantized_matmul receives the weight scales as a non-const lvalue reference, but scales here is a const Tensor&, so passing it directly causes a C++ compilation error.
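The binding constraint behind non_const_scales can be sketched with a stand-in type (FakeTensor, consume, and caller below are hypothetical illustrations, not PyTorch code): a const reference cannot bind to a non-const lvalue-reference parameter, so a local non-const copy is made first. For at::Tensor such a copy is cheap, since it only copies a refcounted handle to the same storage.

```cpp
// Stand-in type to illustrate the C++ reference-binding rule; none of these
// names are PyTorch API.
struct FakeTensor {
  int id;
};

// Like quantized_matmul in this discussion, the callee takes a non-const
// lvalue reference, so it cannot accept a const FakeTensor& directly.
int consume(FakeTensor& t) { return t.id; }

int caller(const FakeTensor& scales) {
  // consume(scales);  // would not compile: const& cannot bind to non-const&
  FakeTensor non_const_scales = scales;  // local copy restores a mutable lvalue
  return consume(non_const_scales);
}
```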

@liangan1 (Contributor)

Generally LGTM.

@EikanWang (Collaborator) left a comment

LGTM

" : expect scales to be 1d tensor with size ",
N);

auto C = at::empty({M, N}, A.options());
Collaborator:

Is it possible to invoke native::empty directly?

@EikanWang EikanWang requested a review from drisspg August 19, 2025 03:02
@liangan1 (Contributor)

@jerryzh168 can you help to review this PR?

@liangan1 (Contributor) commented Sep 1, 2025

@xiaowangintel pls rebase the code and fix the CI issue.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Sep 1, 2025
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Sep 1, 2025
@xiaowangintel xiaowangintel changed the title [WIP][WOQ] Add XPU kernel for _weight_int8pack_mm [WOQ] Add XPU kernel for _weight_int8pack_mm Sep 2, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Sep 3, 2025
@etaf etaf added the ciflow/xpu Run XPU CI tasks label Sep 3, 2025
@liangan1 (Contributor) commented Sep 8, 2025

@drisspg can you help to review this PR?

@xiaowangintel (Contributor, Author)

@jerryzh168 can you help to review this PR?

@guangyey guangyey requested a review from jerryzh168 September 19, 2025 01:55
@jerryzh168 (Contributor) left a comment

can't provide meaningful reviews as I'm not familiar with hardware details, but can stamp.

also should this op live in torchao in the end?

@liangan1 (Contributor)

> can't provide meaningful reviews as I'm not familiar with hardware details, but can stamp.
>
> also should this op live in torchao in the end?

Thanks, Jerry. Yes, this op will be used to speed up WOQ-INT8 in torchao.

@xiaowangintel (Contributor, Author)

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 19, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
Summary:
This PR adds an XPU kernel for aten._weight_int8pack_mm, a weight-only quantized (WOQ) linear operation that was previously supported only on CPU and CUDA.

Motivation:
Same as pytorch#159325.

Pull Request resolved: pytorch#160938
Approved by: https://github.com/EikanWang, https://github.com/ZhiweiYan-96, https://github.com/liangan1, https://github.com/jerryzh168
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
pytorchmergebot pushed a commit that referenced this pull request Dec 19, 2025
Summary:

Supports the woq_int8 inductor pattern on Intel GPU. When using torch.compile, woq_int8 will be lowered to _weight_int8pack_mm instead of falling back to mul().sum(). The Intel GPU backend of _weight_int8pack_mm was added in #160938.

Pull Request resolved: #163615
Approved by: https://github.com/etaf, https://github.com/EikanWang, https://github.com/desertfire, https://github.com/jansel
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), ciflow/xpu (Run XPU CI tasks), Merged, module: cpu (CPU specific problem (e.g., perf, algorithm)), open source, release notes: inductor (aoti)

Projects

Archived in project

9 participants