
Remove amax return from _scaled_mm #128683

Closed
drisspg wants to merge 1 commit into pytorch:main from drisspg:remove-low-precision-option

Conversation

@drisspg
Contributor

@drisspg drisspg commented Jun 14, 2024

Summary

The primary reason for the change was the lack of a current use case and the need to work around two Inductor issues:

  • Tensor arguments as kwarg only
  • multiple outputs from triton templates

If the need for the amax return arises, we can consider either adding it back or, more likely, creating a separate op.

In principle, PyTorch is moving away from "mega ops" that bundle lots of functionality; we instead rely on the compiler to generate appropriately fused kernels.

Changes:

  • This removes the amax return from _scaled_mm. We have found that the common use case is to return in "high precision" (a type with more precision than fp8), and amax is only relevant when returning in low precision.
  • We currently still allow fp8 returns with a scaled result. Perhaps we should ban this as well...

New signature:

def meta_scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
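The semantics of the new single-return signature can be sketched in plain Python (a minimal reference model, not the real CUDA implementation: plain floats stand in for fp8 values, scalar scales stand in for the tensor-wise scale tensors, and `scaled_mm_ref`/`amax` are illustrative names only). The separate `amax` helper shows what callers now compute themselves, since the op no longer returns it:

```python
# Minimal pure-Python sketch of the new _scaled_mm semantics.
# Reference model only: the real op runs fused fp8 kernels
# (cuBLAS/CUTLASS/Triton); plain floats stand in for fp8 here.

def scaled_mm_ref(self_mat, mat2, scale_a, scale_b, bias=None):
    """out = (self * scale_a) @ (mat2 * scale_b) + bias — one return, no amax."""
    m, k = len(self_mat), len(self_mat[0])
    n = len(mat2[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0  # accumulate in "high precision"
            for p in range(k):
                acc += (self_mat[i][p] * scale_a) * (mat2[p][j] * scale_b)
            out[i][j] = acc + (bias[j] if bias is not None else 0.0)
    return out

def amax(mat):
    """With the amax return removed, callers compute it separately."""
    return max(abs(x) for row in mat for x in row)

a = [[1.0, -2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]
out = scaled_mm_ref(a, b, scale_a=0.5, scale_b=2.0)
print(out)        # [[1.0, -2.0], [3.0, 4.0]]
print(amax(out))  # 4.0
```

Computing amax as a separate elementwise reduction is exactly the pattern the compiler can then fuse into surrounding kernels, per the "mega ops" rationale above.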

@pytorch-bot

pytorch-bot bot commented Jun 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128683

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (14 Unrelated Failures)

As of commit 476e817 with merge base f8d60e0 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg force-pushed the remove-low-precision-option branch from 4abcd73 to 1b7669e on June 14, 2024 04:58
@drisspg drisspg changed the title Remove amax return types Remove amax return from _scaled_mm Jun 14, 2024
@drisspg drisspg force-pushed the remove-low-precision-option branch 4 times, most recently from 1c46fb1 to b1f566c on June 14, 2024 23:46
@drisspg drisspg marked this pull request as ready for review June 14, 2024 23:46
@drisspg drisspg requested a review from eqy as a code owner June 14, 2024 23:46
@drisspg drisspg requested review from vkuzo and yangsiyu007 June 15, 2024 00:04
@drisspg drisspg added the topic: not user facing topic category label Jun 15, 2024
@drisspg drisspg force-pushed the remove-low-precision-option branch 2 times, most recently from 5a24d53 to dc1313b on June 15, 2024 01:55
Contributor

@vkuzo vkuzo left a comment


awesome! lg if tests pass

@vkuzo
Contributor

vkuzo commented Jun 15, 2024

for my own curiosity, what was the reason for making scales required?

@drisspg
Contributor Author

drisspg commented Jun 15, 2024

cc @yangsiyu007 on the Inductor constraint that scales can't be optional.

That being said, in retrospect I think this makes more sense. It makes sense that "scaled_mm" requires the scales, lol, since it is pretty rare (modulo testing) that proper use of this function doesn't require scales.

@yangsiyu007
Contributor

yangsiyu007 commented Jun 15, 2024

[Edited] Checked that lowering now works, output: P1419523227
You can see that for tensor-wise scaling, which ran first, AUTOTUNE happened between the ATen _scaled_mm and the Triton templated kernels' configs; for row-wise scaling next, only Triton configs were tuned.

@yangsiyu007
Contributor

for my own curiosity, what was the reason for making scales required?

Inductor doesn't currently support optional tensor inputs; the symptom is that it checks the layout of each input tensor and errors upon seeing None (I tried a workaround of giving it an empty TensorBox, but that leads to incorrect codegen because a pass drops the unused nodes). There is a workaround for handling exactly one optional tensor input, which is why we are okay with the optional bias (we keep two Triton templates, with and without bias). I'd like to work on supporting optional tensor inputs, but since it makes sense for the scales to be non-optional, we thought we'd proceed for now.
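The two-template workaround described above can be illustrated with a toy dispatcher (hypothetical names; the real Inductor machinery registers two Triton templates and selects one at lowering time, so that no kernel ever receives a None tensor):

```python
# Toy illustration of the "two templates" workaround for one optional
# tensor input: rather than passing None into a kernel, lowering picks
# a variant whose signature has no bias argument at all.

def mm_kernel_with_bias(a, b, bias):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) + bias[j]
             for j in range(len(b[0]))] for i in range(len(a))]

def mm_kernel_no_bias(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def lower_mm(a, b, bias=None):
    # Select the template at lowering time so every kernel input is a
    # real tensor. This doubles the template count per optional input,
    # which is why it doesn't scale past one — and why the scales were
    # made required instead of optional.
    if bias is None:
        return mm_kernel_no_bias(a, b)
    return mm_kernel_with_bias(a, b, bias)

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]
print(lower_mm(a, b))                   # [[1.0, 2.0], [3.0, 4.0]]
print(lower_mm(a, b, bias=[1.0, 1.0]))  # [[2.0, 3.0], [4.0, 5.0]]
```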

@drisspg drisspg force-pushed the remove-low-precision-option branch from dc1313b to 476e817 on June 17, 2024 03:01
@drisspg drisspg added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 17, 2024
@drisspg
Contributor Author

drisspg commented Jun 17, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@drisspg
Contributor Author

drisspg commented Jun 17, 2024

@pytorchbot merge -i

drisspg added a commit to drisspg/pytorch that referenced this pull request Jun 19, 2024
Summary:
Pull Request resolved: pytorch#129037

This forward fixes this diff:
D58699985

Since we have a few things in flight it would be much better to forward fix this test

Test Plan: buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:test_inductor_cuda -- --exact 'caffe2/test/inductor:test_inductor_cuda - test_red_followed_by_transposed_pointwise (caffe2.test.inductor.test_torchinductor.TritonCodeGenTests)'

Differential Revision: D58767577
pytorchmergebot pushed a commit that referenced this pull request Jun 22, 2024
Summary:
This forward fixes this diff:
D58699985

Since we have a few things in flight it would be much better to forward fix this test

Test Plan: buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:test_inductor_cuda -- --exact 'caffe2/test/inductor:test_inductor_cuda - test_red_followed_by_transposed_pointwise (caffe2.test.inductor.test_torchinductor.TritonCodeGenTests)'

Differential Revision: D58767577

Pull Request resolved: #129037
Approved by: https://github.com/vkuzo
pytorchmergebot pushed a commit that referenced this pull request Jul 12, 2024
`_scaled_mm` no longer returns `amax` (see #128683)

Pull Request resolved: #130582
Approved by: https://github.com/drisspg
pytorchmergebot pushed a commit that referenced this pull request Jul 22, 2024
… cases (#130868)

Continuing #128683 and #130582.

The API of _scaled_mm has changed; for example, there is now only one return value. So change the AOTI API as well.

Also, tested the fp8 tests offline. test_fp8_abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface would fail with `error: use of undeclared identifier 'float8_e4m3fn'` and `error: use of undeclared identifier 'half'`, so skipping it for now.

The reason this wasn't known earlier is probably because the CI doesn't use H100.

Pull Request resolved: #130868
Approved by: https://github.com/drisspg, https://github.com/chenyang78, https://github.com/desertfire
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
`_scaled_mm` no longer returns `amax` (see pytorch#128683)

Pull Request resolved: pytorch#130582
Approved by: https://github.com/drisspg
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
… cases (pytorch#130868)

Continuing pytorch#128683 and pytorch#130582.

The API of _scaled_mm has changed; for example, there is now only one return value. So change the AOTI API as well.

Also, tested the fp8 tests offline. test_fp8_abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface would fail with `error: use of undeclared identifier 'float8_e4m3fn'` and `error: use of undeclared identifier 'half'`, so skipping it for now.

The reason this wasn't known earlier is probably because the CI doesn't use H100.

Pull Request resolved: pytorch#130868
Approved by: https://github.com/drisspg, https://github.com/chenyang78, https://github.com/desertfire
pytorchmergebot pushed a commit that referenced this pull request Jul 30, 2024
Add the Inductor lowering for `torch._scaled_mm`, whose API was last updated in #128683.

The lowering does:
- for tensor-wise scaling, auto-tune between the default ATen kernel (cuBLAS) and Triton kernel configurations.
- for row-wise scaling, auto-tune between the default ATen kernel (CUTLASS kernel added in #125204) and Triton kernel configurations.

The Triton kernel template is based on htyu/FBGEMM@3ad9031 (D56337896) by @choutim, without using SPLIT_K, and that of mm `torch/_inductor/kernel/mm.py`

## Testing:
- Logging shows max-autotune tuning (`AUTOTUNE scaled_mm`) for both tensor-wise and row-wise scaling when called with the two scaling types.
- Row-wise scaling allows operator fusion between preceding pointwise/reduction op and amax/cast:
    - output code Evaluating m=256, n=256, k=256, fusion_case='pointwise', scaling_mode='row'
        - P1477224245 - 2 kernels
    - output code Evaluating m=2048, n=256, k=2048, fusion_case='reduction', scaling_mode='row'
        - P1477227340 - 2 kernels

- UT `python test/inductor/test_fp8.py -- TestFP8Lowering`

## Benchmarking

Eager/compiled tensor-wise/row-wise scaling for various shapes:
https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669
- Some of the “compiled” cases are slightly slower than “eager”; this is because max-autotune selected the ATen kernel in the compiled case, and I think the remaining discrepancy is variance.

Eager/compiled tensor-wise/row-wise scaling with pointwise/reduction preceding op for various shapes:
https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446

## Questions for reviewers:
- Should the type of the accumulator `ACC_TYPE` always be in float32? If not, where is this type set (output layout?)?

## Todo:
- Make the Triton template use the improved persistent kernel version (pytorch/FBGEMM#2735 by @htyu)

Pull Request resolved: #130422
Approved by: https://github.com/ipiszy
pytorchmergebot pushed a commit that referenced this pull request Sep 9, 2024
amax was removed from _scaled_mm by #128683. Remove it from the internal at::cuda::blas::scaled_gemm, as well.  This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result.
Pull Request resolved: #135421
Approved by: https://github.com/drisspg, https://github.com/eqy
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
amax was removed from _scaled_mm by pytorch#128683. Remove it from the internal at::cuda::blas::scaled_gemm, as well.  This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result.
Pull Request resolved: pytorch#135421
Approved by: https://github.com/drisspg, https://github.com/eqy
amd-sriram pushed a commit to ROCm/pytorch that referenced this pull request Nov 19, 2024
amax was removed from _scaled_mm by pytorch#128683. Remove it from the internal at::cuda::blas::scaled_gemm, as well.  This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result.
Pull Request resolved: pytorch#135421
Approved by: https://github.com/drisspg, https://github.com/eqy
jeffdaily added a commit to ROCm/pytorch that referenced this pull request Nov 21, 2024
amax was removed from _scaled_mm by pytorch#128683. Remove it from the internal at::cuda::blas::scaled_gemm, as well.  This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result.
Pull Request resolved: pytorch#135421
Approved by: https://github.com/drisspg, https://github.com/eqy
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Nov 21, 2024
amax was removed from _scaled_mm by pytorch#128683. Remove it from the internal
at::cuda::blas::scaled_gemm, as well. This allows hipBLASLt to find
additional solutions rather than forcing amax to be used and then
discarding the result. Pull Request resolved:
pytorch#135421 Approved by:
https://github.com/drisspg, https://github.com/eqy
amd-sriram pushed a commit to ROCm/pytorch that referenced this pull request Nov 22, 2024
amax was removed from _scaled_mm by pytorch#128683. Remove it from the internal at::cuda::blas::scaled_gemm, as well.  This allows hipBLASLt to find additional solutions rather than forcing amax to be used and then discarding the result.
Pull Request resolved: pytorch#135421
Approved by: https://github.com/drisspg, https://github.com/eqy
amd-sriram added a commit to ROCm/pytorch that referenced this pull request Nov 22, 2024
amd-sriram added a commit to ROCm/pytorch that referenced this pull request Dec 2, 2024
…comparison in the unit test, removing skip rocm decorator with cherry pick of 3ea3914
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Dec 5, 2024
…vs_emulated_*float*_cuda and Updating unit test case based on removing amax from _scaled_mm (#1762)

Fixes ROCm/frameworks-internal#8493 and
ROCm/frameworks-internal#10198

cherry pick commit -
39a6179

`amax was removed from _scaled_mm by pytorch#128683. Remove it from the
internal at::cuda::blas::scaled_gemm, as well. This allows hipBLASLt to
find additional solutions rather than forcing amax to be used and then
discarding the result.`

Also removing amax comparison in the unit test.
pytorchmergebot pushed a commit that referenced this pull request Jan 20, 2025
Looks like `out_fp8` should use matmul without scales and `out_fp8_s` with scales.
Scales were optional arguments before PR #128683; then test_float8_scale started comparing two identical results and lost its meaning.
Reason for making scales required: #128683 (comment)

This PR uses scale=1.0 to compare result with scaled matmul
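The fix relies on the identity that a scaled matmul with unit scales equals the unscaled matmul. A quick sketch of the comparison the updated test performs (plain-Python stand-ins for the fp8 ops, not the actual test code):

```python
# Sketch of the repaired test logic: with scale=1.0 the scaled and
# unscaled paths must agree, so comparing them is meaningful again.

def mm(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def scaled_mm(a, b, scale_a, scale_b):
    # Tensor-wise scaling commutes with the matmul, so it can be
    # modeled as a post-hoc elementwise multiply.
    return [[x * scale_a * scale_b for x in row] for row in mm(a, b)]

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
# With scale=1.0 the two code paths coincide...
assert scaled_mm(a, b, 1.0, 1.0) == mm(a, b)
# ...and with non-trivial scales they differ — the distinction the old
# test (which compared two identical scaled calls) failed to exercise.
assert scaled_mm(a, b, 2.0, 1.0) != mm(a, b)
print("ok")
```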

Pull Request resolved: #143912
Approved by: https://github.com/drisspg, https://github.com/malfet, https://github.com/pruthvistony
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Mar 17, 2025
…vs_emulated_*float*_cuda and Updating unit test case based on removing amax from _scaled_mm (#1762)

Fixes ROCm/frameworks-internal#8493 and
ROCm/frameworks-internal#10198

cherry pick commit -
39a6179

`amax was removed from _scaled_mm by pytorch#128683. Remove it from the
internal at::cuda::blas::scaled_gemm, as well. This allows hipBLASLt to
find additional solutions rather than forcing amax to be used and then
discarding the result.`

Also removing amax comparison in the unit test.

Labels

ciflow/inductor, ciflow/trunk, Merged, topic: not user facing

4 participants