
add fp8 scaled_mm for XPU #140972

Closed
yuchengliu1 wants to merge 22 commits into pytorch:main from yuchengliu1:scaled_mm_xpu

Conversation

@yuchengliu1
Contributor

@yuchengliu1 yuchengliu1 commented Nov 18, 2024

This PR introduces fp8 scaled_mm for Intel XPU. The UT test\xpu\test_scaled_mm.py is a subset of test\test_matmul_cuda.py.

torch-xpu-ops currently falls scaled_mm back to CPU, so we will get a warning. This warning has no influence on this op.


There will be a PR to remove the fallback in torch-xpu-ops after this PR is merged.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @kwen2501 @c-p-i-o @yf225 @ColinPeppler @desertfire
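For readers unfamiliar with the op, the description above can be sketched in plain Python. This is a minimal stand-in for the tensor-wise-scaling semantics that torch._scaled_mm computes (the real op runs fp8 tensors through oneDNN on XPU; here the operands are ordinary floats representing already-dequantized fp8 values, and the function name is hypothetical):

```python
# Hedged sketch, plain Python, no torch: tensor-wise scaled matmul.
# out[i][j] = (sum_k a[i][k] * b[k][j]) * scale_a * scale_b
def scaled_mm(a, b, scale_a, scale_b):
    rows, inner, cols = len(a), len(b), len(b[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for k in range(inner):
            for j in range(cols):
                out[i][j] += a[i][k] * b[k][j]
    # Apply the per-tensor dequantization scales once, after the matmul.
    return [[v * scale_a * scale_b for v in row] for row in out]

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]  # identity, so out == a * (scale_a * scale_b)
print(scaled_mm(a, b, 0.5, 2.0))  # [[1.0, 2.0], [3.0, 4.0]]
```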

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Nov 18, 2024
@pytorch-bot

pytorch-bot bot commented Nov 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140972

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 3 Unrelated Failures

As of commit 44bbed8 with merge base 09e5a93:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: inductor module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration release notes: releng release notes category labels Jan 19, 2025
@yuchengliu1 yuchengliu1 changed the base branch from main to gh/yanbing-j/28/head January 19, 2025 09:28
@etaf etaf added the ciflow/xpu Run XPU CI tasks label Jan 20, 2025
engine,
mat1.data_ptr());
auto mat2_c = mat2.contiguous();
dnnl::memory weight = at::native::onednn::make_onednn_memory(
Contributor

Strided inputs are not supported by oneDNN?

Contributor Author

Yes. Deleted line 280 and used the tensor's stride in the oneDNN memory desc.

if (with_bias) {
args.insert({DNNL_ARG_BIAS, onednn_bias});
}
// auto sycl_queue = dnnl::sycl_interop::get_queue(stream);
Contributor

Suggest removing this line.

Contributor Author

Removed.

at::native::resize_output(out, {mat1.size(0), mat2.size(1)});
onednn::scaled_matmul(
out, mat1, mat2, bias, scale_a, scale_b, onednn::Attr());
return out;
Contributor

torch._scaled_mm supports out_dtype, and oneDNN also supports different output data types (https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html#data-types); suggest adding the related functionality.

Contributor Author

The function only checks that the input out_dtype matches the dtype of the out tensor. It places no restrictions on out_dtype itself.
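The check described in this reply can be sketched in a few lines of plain Python (names and error type are hypothetical, not the actual PyTorch internals): the op validates consistency between the requested out_dtype and the provided out tensor, without whitelisting particular dtypes.

```python
# Hedged sketch of the out_dtype consistency check discussed above.
def check_out_dtype(out_dtype, out_tensor_dtype):
    # Only consistency is enforced; any dtype pair that matches passes.
    if out_dtype is not None and out_dtype != out_tensor_dtype:
        raise TypeError(
            f"out_dtype {out_dtype!r} does not match out tensor dtype "
            f"{out_tensor_dtype!r}"
        )

check_out_dtype("bfloat16", "bfloat16")      # ok: dtypes agree
try:
    check_out_dtype("float16", "bfloat16")   # mismatch: rejected
except TypeError as e:
    print("rejected:", e)
```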

@yuchengliu1 yuchengliu1 changed the base branch from gh/yanbing-j/28/head to main March 13, 2025 06:37
@pytorch-bot pytorch-bot bot added module: dynamo oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: quantization release notes category and removed ciflow/xpu Run XPU CI tasks labels Mar 13, 2025
@guangyey guangyey marked this pull request as ready for review March 18, 2025 02:21
@guangyey guangyey marked this pull request as draft March 18, 2025 02:21
Collaborator

What does "current platform" mean? Is it related to the CPU, the GPU, the OS, or the oneDNN version? Is it about hardware or software?

Contributor Author

This try/catch avoids oneDNN failing on configs it has not implemented. The failure may occur because oneDNN has not implemented the config yet, or because the hardware does not support it.

Collaborator

Based on your implementation, my understanding is that removing this try-catch block might cause oneDNN to crash on certain GPUs—possibly on integrated GPUs? Is that correct?
Could oneDNN confirm which version will provide an API to query FP8 primitive support?
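The try/catch pattern under discussion can be illustrated with a plain-Python analogue (every name here is illustrative, not the actual PyTorch/oneDNN API): attempt the fast primitive, and if the backend reports the config as unimplemented, fall back to a reference path instead of crashing.

```python
# Hedged sketch of the fallback-on-unimplemented pattern discussed above.
class UnimplementedError(RuntimeError):
    """Stand-in for the backend signalling an unsupported config."""

def fast_fp8_matmul(a, b):
    # Simulate a platform where the fp8 primitive is unavailable.
    raise UnimplementedError("fp8 primitive unavailable on this platform")

def reference_matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col))
             for col in zip(*b)] for row in a]

def scaled_mm_with_fallback(a, b):
    try:
        return fast_fp8_matmul(a, b)
    except UnimplementedError:
        # The reviewer's point: without this catch, unsupported
        # hardware/configs would surface as a hard failure.
        return reference_matmul(a, b)

print(scaled_mm_with_fallback([[1, 2]], [[3], [4]]))  # [[11]]
```

A query API, as asked for above, would let the dispatcher choose a path up front instead of catching the failure after the fact.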

@yuchengliu1 yuchengliu1 marked this pull request as ready for review March 26, 2025 10:08
@soulitzer soulitzer requested a review from vkuzo March 26, 2025 22:47
Contributor

nit: we should not import torchao here, as torchao has a dependency on PyTorch. All these functions in torchao are also not in the BC surface so we really should not use them in other repos.

Contributor Author

Removed torchao from the test.

@EikanWang
Collaborator

@yuchengliu1, could you please help investigate the failures?

@yuchengliu1
Contributor Author

I noticed that _scaled_mm is already registered to fall back to CPU in https://github.com/intel/torch-xpu-ops/blob/7e51233d26717f5d4d402685786a0f4c2aa4198e/src/ATen/native/xpu/XPUFallback.template#L212, but this PR re-registers it, resulting in CI failures. So I think you need to remove this op from torch-xpu-ops and update the commit pin within this PR.

Yes, there is already a PR in torch-xpu-ops.

@yuchengliu1
Contributor Author

@yuchengliu1, could you please help investigate the failures?

Most of these failures are due to the fallback to CPU; the CI is actually running the CPU scaled_mm. We need to find a way to remove the fallback first and then trigger CI to test the XPU scaled_mm.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 5, 2025
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 5, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 7, 2025
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 7, 2025
Collaborator

@guangyey guangyey left a comment

Just a nit in the try-catch handling; otherwise, LGTM.
Adding @ZhiweiYan-96 in case you have some comments here.

@guangyey guangyey requested a review from desertfire August 7, 2025 05:54
if (scale_a.numel() == 1) {
op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
} else {
// oneDNN 3.7 does not support a per-token src scale, so use a post-op mul as a workaround
Collaborator

@guangyey guangyey Aug 7, 2025

Currently, XPU uses oneDNN 3.8. Do you know which oneDNN version will support this feature?

Contributor Author

The current oneDNN does not support this feature.
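The workaround in the diff hunk above relies on a small algebraic fact: a per-row (per-token) scale on the src operand commutes with the matmul, since diag(s) @ (A @ B) == (diag(s) @ A) @ B. So scaling the output rows with a post-op multiply is mathematically equivalent to scaling src beforehand. A plain-Python check (names illustrative):

```python
# Hedged sketch: per-row src scaling commutes with matmul, which is why
# a post-op multiply can stand in for an unsupported per-token src scale.
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col))
             for col in zip(*b)] for row in a]

def scale_rows(m, s):
    # Multiply row i of m by scalar s[i], i.e. diag(s) @ m.
    return [[v * si for v in row] for row, si in zip(m, s)]

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
s = [0.5, 2.0]  # per-row (per-token) src scales

pre = matmul(scale_rows(a, s), b)    # scale src before the matmul
post = scale_rows(matmul(a, b), s)   # post-op multiply on the output
print(pre == post)  # True
```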

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 8, 2025
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 8, 2025
@guangyey
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/140972/head returned non-zero exit code 1

Rebasing (1/21)
Rebasing (2/21)
Rebasing (3/21)
Rebasing (4/21)
Rebasing (5/21)
Rebasing (6/21)
Rebasing (7/21)
Rebasing (8/21)
Rebasing (9/21)
Rebasing (10/21)
Rebasing (11/21)
Rebasing (12/21)
Rebasing (13/21)
Rebasing (14/21)
Rebasing (15/21)
Rebasing (16/21)
Rebasing (17/21)
Rebasing (18/21)
Rebasing (19/21)
Auto-merging third_party/xpu.txt
CONFLICT (content): Merge conflict in third_party/xpu.txt
error: could not apply 70416f47c16... update xpu-ops pin
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply 70416f47c16... update xpu-ops pin

Raised by https://github.com/pytorch/pytorch/actions/runs/16930825608

auto& engine = GpuEngineManager::Instance().get_engine();
auto& stream = GpuStreamManager::Instance().get_stream();

// Validation checks have passed, let's resize the output to the actual size
Contributor

@Stonepia Stonepia Aug 14, 2025

This comment actually refers to

 at::native::resize_output(out, {mat1.size(0), mat2.size(1)});

in Blas.cpp?

https://github.com/pytorch/pytorch/pull/140972/files#diff-9ae74b4a8990350760237cc09e715cc25a333f1d0655bd13cddb71c62cea2a39R444

Contributor Author

Sorry, I don't follow. Could you explain it in more detail?

Contributor

@Stonepia Stonepia Aug 14, 2025

Oh, sorry for not commenting clearly. I mean, in the comment:

"validation checks have passed": I suppose this refers to the TORCH_CHECK(...) calls at the beginning of _scaled_mm_out_xpu.

"resize the output to actual size": this should mean the resize_output(out, ...) at line 444 of Blas.cpp, because formerly this out is initialized as an empty 0 tensor (line 459).

So from my understanding, this comment should sit at line 444 rather than at this oneDNN size creation. But it is not a big deal anyway.

Contributor Author

@yuchengliu1 yuchengliu1 Aug 15, 2025

The function is not only used in _scaled_mm_xpu. It can be called like torch._scaled_mm(a, b, ..., out=out); out is an external variable here, and it may need to be resized.
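The out= situation described in this reply can be sketched in plain Python (FakeTensor and the function names are stand-ins, not the real PyTorch types): a caller-supplied out buffer may arrive empty, so the op must resize it to (M, N) before the kernel writes into it.

```python
# Hedged sketch of the out= resize behavior discussed above.
class FakeTensor:
    """Minimal stand-in tracking only a shape."""
    def __init__(self, shape):
        self.shape = shape

def resize_output(out, shape):
    # The real at::native::resize_output also reallocates storage;
    # here we only track the shape change.
    if out.shape != shape:
        out.shape = shape

def scaled_mm_out(mat1, mat2, out):
    # out may be an external, possibly empty tensor: resize to (M, N).
    resize_output(out, (mat1.shape[0], mat2.shape[1]))
    return out

out = FakeTensor((0,))  # caller created an empty tensor
scaled_mm_out(FakeTensor((4, 8)), FakeTensor((8, 16)), out)
print(out.shape)  # (4, 16)
```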

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 14, 2025
@yuchengliu1 yuchengliu1 closed this Nov 4, 2025
@github-project-automation github-project-automation bot moved this from Pre-Review Required to Done in PyTorch Intel Nov 4, 2025
pytorchmergebot pushed a commit that referenced this pull request Nov 14, 2025
This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise Scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise Scaling:  `fp8_e4m3` and `fp8_e5m2`

It leaves BlockWise Scaling to the next PR, so that it requires less reviewing effort.

This is the first PR: it only adds `scaled_mm_xpu` but does not register it. We separate this out to reduce reviewing effort.

Secondly, there is a `scaled_mm_v2` API in #164141. We will align with it once v1 is cleaned up.

**Co-author:** @yuchengliu1, @carsonwang

## PR stack:

- -> #165978 : implementation of XPU scaled_mm and oneDNN kernel
- #167518 : implementation of XPU scaled_mm_v2
- #166056 : Op registration

## Test Status:

1. Relies on the changes in intel/torch-xpu-ops#1746; otherwise the op will fall back to CPU.
2. This PR does not include tests, the tests are enabled in #166056.

## Credit:

This work is based on @yuchengliu1's work in #140972. We created a new PR to align the API/checks with CUDA, so there will be less porting effort.

## FP8 Task tracker:
We will track all the scaled_mm related tasks in: #167170

Pull Request resolved: #165978
Approved by: https://github.com/liangan1, https://github.com/EikanWang

Co-authored-by: Eikan Wang <eikan.wang@intel.com>
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
…h#165978)

chuanqi129 pushed a commit to intel/torch-xpu-ops that referenced this pull request Nov 21, 2025
scaled_mm is now supported in PyTorch
(pytorch/pytorch#140972).
This fallback would cause duplicate registration;
remove this fallback after
pytorch/pytorch#140972 is merged.

---------

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Co-authored-by: Su Tong <tong.su@intel.com>
Labels

ciflow/xpu Run XPU CI tasks module: cpu CPU specific problem (e.g., perf, algorithm) module: inductor module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration open source release notes: inductor (aoti) release notes: quantization release notes category Stale triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

10 participants