[xpu][feature] [1/3] add fp8 scaled_mm implementation for XPU#165978
Stonepia wants to merge 26 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165978
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 949d894 with merge base b91a2ab. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one that adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info. Caused by:
Attention! One of PyTorch's C-stable API files was changed. You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version. Caused by:
@Stonepia This PR is still very large and the review effort is heavy. Since only tensor-wise and row-wise scaling are supported in this PR, I suggest removing the handling for the unsupported scaling formats, while keeping the design extensible so that more scaling formats can be added later.
@pytorchbot label "module: xpu"
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Successfully rebased 435d529 to 5814f66.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed; the first few are: xpu / linux-noble-xpu-n-py3.10 / test (default, 6, 12, linux.idc.xpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed; the first few are: xpu / linux-noble-xpu-n-py3.10 / test (default, 6, 12, linux.idc.xpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…67518) This PR implements `scaled_mm_v2` for XPU, following the work in #164141.

## PR stack:
- #165978: implementation of XPU scaled_mm and oneDNN kernel
- -> #167518: implementation of XPU scaled_mm_v2
- #166056: op registration

Pull Request resolved: #167518. Approved by: https://github.com/EikanWang, https://github.com/liangan1
…h#165978) This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise Scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise Scaling: `fp8_e4m3` and `fp8_e5m2`

It leaves BlockWise Scaling to the next PR to keep the review effort manageable. This is the first PR: it only adds `scaled_mm_xpu` and does not register it; we separate the registration out to reduce review effort. Secondly, there is a `scaled_mm_v2` API in pytorch#164141; we will align with it once v1 is cleaned up.

**Co-authors:** @yuchengliu1, @carsonwang

## PR stack:
- -> pytorch#165978: implementation of XPU scaled_mm and oneDNN kernel
- pytorch#167518: implementation of XPU scaled_mm_v2
- pytorch#166056: op registration

## Test Status:
1. Relies on the changes in intel/torch-xpu-ops#1746; otherwise the op will fall back to CPU.
2. This PR does not include tests; the tests are enabled in pytorch#166056.

## Credit:
This work is based on @yuchengliu1's work in pytorch#140972. We created a new PR to align the API and checks with CUDA, so there will be less porting effort.

## FP8 Task tracker:
We will track all the scaled_mm related tasks in pytorch#167170.

Pull Request resolved: pytorch#165978. Approved by: https://github.com/liangan1, https://github.com/EikanWang. Co-authored-by: Eikan Wang <eikan.wang@intel.com>
…xpu (#166056) This PR registers the `scaled_mm` op for XPU support. It does the following:
1. Registers the `_scaled_mm` and `_scaled_mm_v2` ops for XPU.
2. Enables XPU tests in `test_scaled_matmul_cuda.py`.
3. Updates the torch-xpu-ops pin to remove the fallback of `scaled_mm` to the CPU implementation.

## PR stack:
- #165978: implementation of XPU scaled_mm and oneDNN kernel
- #167518: implementation of XPU scaled_mm_v2
- -> #166056: op registration

## Task tracker:
We will track all the scaled_mm related tasks in #167170.

Pull Request resolved: #166056. Approved by: https://github.com/EikanWang, https://github.com/slayton58, https://github.com/drisspg
This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise Scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise Scaling: `fp8_e4m3` and `fp8_e5m2`

It leaves BlockWise Scaling to the next PR to keep the review effort manageable.

This is the first PR: it only adds `scaled_mm_xpu` and does not register it. We separate the registration out to reduce review effort. Secondly, there is a `scaled_mm_v2` API in #164141; we will align with it once v1 is cleaned up.

**Co-authors:** @yuchengliu1, @carsonwang

## PR stack:
- -> #165978: implementation of XPU scaled_mm and oneDNN kernel
- #167518: implementation of XPU scaled_mm_v2
- #166056: op registration (register `scaled_mm` and `scaled_mm_v2` for xpu)

## Test Status:
1. Relies on the changes in intel/torch-xpu-ops#1746; otherwise the op will fall back to CPU.
2. This PR does not include tests; the tests are enabled in the op-registration PR (#166056).

## Credit:
This work is based on @yuchengliu1's work in #140972. We created a new PR to align the API and checks with CUDA, so there will be less porting effort.

## FP8 Task tracker:
We will track all the scaled_mm related tasks in: #167170
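The two scaling modes this PR supports can be sketched numerically. The following is an illustrative NumPy emulation of the dequantize-after-matmul contract (TensorWise: one scalar scale per tensor; RowWise: a per-row scale of shape `[M, 1]` for A and a per-column scale of shape `[1, N]` for B, mirroring the CUDA `scaled_mm` convention). It is not the oneDNN kernel: the fp8 cast is stood in by a clip at the fp8_e4m3 maximum, and all function names here are hypothetical.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in fp8_e4m3

def quantize_tensorwise(x):
    # TensorWise scaling: a single scalar scale for the whole tensor.
    scale = np.abs(x).max() / FP8_E4M3_MAX
    x_lp = np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # stand-in for the fp8 cast
    return x_lp.astype(np.float32), np.float32(scale)

def quantize_rowwise(a, b):
    # RowWise scaling: scale_a has shape [M, 1], scale_b has shape [1, N].
    scale_a = np.abs(a).max(axis=1, keepdims=True) / FP8_E4M3_MAX
    scale_b = np.abs(b).max(axis=0, keepdims=True) / FP8_E4M3_MAX
    a_lp = np.clip(a / scale_a, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    b_lp = np.clip(b / scale_b, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return a_lp, b_lp, scale_a, scale_b

def scaled_mm(a_lp, b_lp, scale_a, scale_b):
    # Dequantize-after-matmul contract: out = (a_lp @ b_lp) * scale_a * scale_b.
    # Broadcasting makes the same line cover scalar and row/column-wise scales.
    return (a_lp @ b_lp) * scale_a * scale_b

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8)).astype(np.float32)
b = rng.standard_normal((8, 3)).astype(np.float32)

# TensorWise scaling
a_lp, sa = quantize_tensorwise(a)
b_lp, sb = quantize_tensorwise(b)
out_tw = scaled_mm(a_lp, b_lp, sa, sb)

# RowWise scaling
a_rw, b_rw, sa_rw, sb_rw = quantize_rowwise(a, b)
out_rw = scaled_mm(a_rw, b_rw, sa_rw, sb_rw)
```

Because this emulation skips the actual 8-bit rounding, both results match `a @ b` up to float32 rounding; a real fp8 cast would add quantization error on top of this contract.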
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01 @gujinghui @EikanWang @fengyuan14 @guangyey