
FP8 rowwise scaling#125204

Closed
drisspg wants to merge into pytorch:main from drisspg:add-row-wise-scaling

Conversation

@drisspg
Contributor

@drisspg drisspg commented Apr 30, 2024

Summary

This pull request introduces an fp8 row-scaling kernel as an optional implementation for scaled_mm. The kernel selection is based on the scaling tensors of the inputs. For inputs x and y of shape [M, K] and [K, N] respectively, the following conditions must be met:

  • x's scale should be a 1-dimensional tensor of length M.
  • y's scale should be a 1-dimensional tensor of length N.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for y are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".

The following two PRs were required to enable local builds:

  • PR #126185
  • PR #125523

Todo

We still do not build our Python wheels with this architecture.

@ptrblck @malfet, should we replace sm_90 with sm_90a?

The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954

ifdef

I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel, but I was having a hell of a time with it, so I am not sure of the right way to do this.

Kernel Credit:
@jwfromm

cc @yanbing-j @vkuzo @albanD @kadeng

@pytorch-bot

pytorch-bot bot commented Apr 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125204

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (31 Unrelated Failures)

As of commit 4448397 with merge base 4448397:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed, likely due to flakiness present on trunk, and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg force-pushed the add-row-wise-scaling branch 7 times, most recently from 54a84cc to dac6a96 on May 2, 2024 02:00
@drisspg drisspg force-pushed the add-row-wise-scaling branch from dac6a96 to 110261b on May 2, 2024 18:31
@drisspg drisspg requested a review from malfet May 2, 2024 19:27
@drisspg drisspg force-pushed the add-row-wise-scaling branch 7 times, most recently from 7d9bc17 to 73b3a39 on May 20, 2024 21:51
@drisspg drisspg force-pushed the add-row-wise-scaling branch from 73b3a39 to 63c30ed on May 23, 2024 22:28
@drisspg
Contributor Author

drisspg commented May 23, 2024

❯ nm  -C /home/drisspg/meta/pytorch/torch/lib/libtorch_cuda.so | grep cuT
0000000002561b90 T cuTensorMapEncodeTiled
0000000000ef1110 t at::cuda::detail::_stubs::cuTensorMapEncodeTiled(CUtensorMap_st*, CUtensorMapDataType_enum, unsigned int, void*, unsigned long const*, unsigned long const*, unsigned int const*, unsigned int const*, CUtensorMapInterleave_enum, CUtensorMapSwizzle_enum, CUtensorMapL2promotion_enum, CUtensorMapFloatOOBfill_enum)
0000000000cf4fa7 t at::cuda::detail::_stubs::cuTensorMapEncodeTiled(CUtensorMap_st*, CUtensorMapDataType_enum, unsigned int, void*, unsigned long const*, unsigned long const*, unsigned int const*, unsigned int const*, CUtensorMapInterleave_enum, CUtensorMapSwizzle_enum, CUtensorMapL2promotion_enum, CUtensorMapFloatOOBfill_enum) [clone .cold]

This symbol shadowing doesn't seem right.

@drisspg
Contributor Author

drisspg commented May 24, 2024

After some preprocessor shenanigans, I think I got it into a state that seems better, but I would love some feedback from packaging experts:

❯ nm -C /home/drisspg/meta/pytorch/torch/lib/libtorch_cuda.so | grep cuT;
0000000002561680 t nvrtc_cuTensorMapEncodeTiled(CUtensorMap_st*, CUtensorMapDataType_enum, unsigned int, void*, unsigned long const*, unsigned long const*, unsigned int const*, unsigned int const*, CUtensorMapInterleave_enum, CUtensorMapSwizzle_enum, CUtensorMapL2promotion_enum, CUtensorMapFloatOOBfill_enum) [clone .constprop.1]
0000000000ef10c0 t at::cuda::detail::_stubs::cuTensorMapEncodeTiled(CUtensorMap_st*, CUtensorMapDataType_enum, unsigned int, void*, unsigned long const*, unsigned long const*, unsigned int const*, unsigned int const*, CUtensorMapInterleave_enum, CUtensorMapSwizzle_enum, CUtensorMapL2promotion_enum, CUtensorMapFloatOOBfill_enum)
0000000000cf4f57 t at::cuda::detail::_stubs::cuTensorMapEncodeTiled(CUtensorMap_st*, CUtensorMapDataType_enum, unsigned int, void*, unsigned long const*, unsigned long const*, unsigned int const*, unsigned int const*, CUtensorMapInterleave_enum, CUtensorMapSwizzle_enum, CUtensorMapL2promotion_enum, CUtensorMapFloatOOBfill_enum) [clone .cold]

@drisspg drisspg force-pushed the add-row-wise-scaling branch from f8d8979 to 7577f4a on May 25, 2024 03:35
@drisspg drisspg marked this pull request as ready for review May 25, 2024 03:43
@drisspg drisspg requested a review from eqy as a code owner May 25, 2024 03:43
@drisspg drisspg force-pushed the add-row-wise-scaling branch from 7577f4a to e8510c6 on May 25, 2024 16:52
@drisspg drisspg added the ciflow/trunk Trigger trunk jobs on your pull request label May 25, 2024
@pytorchmergebot
Collaborator

@drisspg
Contributor Author

drisspg commented Jun 5, 2024

@pytorchbot merge -f "I don't think these failures are related"

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; instead, consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
# Summary
This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".

The following two PRs were required to enable local builds:
- [PR pytorch#126185](pytorch#126185)
- [PR pytorch#125523](pytorch#125523)

### Todo
We still do not build our Python wheels with this architecture.

@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?

The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954

#### ifdef

I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel, but I was having a hell of a time with it, so I am not sure of the right way to do this.

Kernel Credit:
@jwfromm

Pull Request resolved: pytorch#125204
Approved by: https://github.com/lw
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
This reverts commit 923edef.

Reverted pytorch#125204 on behalf of https://github.com/atalman due to Broke nightlies and internal tests ([comment](pytorch#125204 (comment)))
@atalman
Contributor

atalman commented Jun 6, 2024

@pytorchmergebot revert -c ghfirst -m "Sorry need to revert this failing, on internal CI. I suggest to reimport this and try to land internally resolving all issues"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Jun 6, 2024
This reverts commit 5dc9128.

Reverted #125204 on behalf of https://github.com/atalman due to Sorry need to revert this failing, on internal CI. I suggest to reimport this and try to land internally resolving all issues ([comment](#125204 (comment)))
@pytorchmergebot
Collaborator

@drisspg your PR has been successfully reverted.

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
This reverts commit 5dc9128.

Reverted pytorch#125204 on behalf of https://github.com/atalman due to Sorry need to revert this failing, on internal CI. I suggest to reimport this and try to land internally resolving all issues ([comment](pytorch#125204 (comment)))
@drisspg drisspg closed this Jun 18, 2024
@drisspg drisspg force-pushed the add-row-wise-scaling branch from 4cdd02a to 4448397 on June 18, 2024 19:08
pytorchmergebot pushed a commit that referenced this pull request Jun 19, 2024
# Summary
First PR got reverted and needed a redo

This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".

The following two PRs were required to enable local builds:
- [PR #126185](#126185)
- [PR #125523](#125523)

### Todo
We still do not build our Python wheels with this architecture.

@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?

The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954

#### ifdef

I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel, but I was having a hell of a time with it, so I am not sure of the right way to do this.

Kernel Credit:
@jwfromm

Pull Request resolved: #128989
Approved by: https://github.com/yangsiyu007, https://github.com/vkuzo
@cora-codes
Contributor

@drisspg how should we resolve this for now on the extension side? <ATen/cuda/nvrtc_stub/ATenNVRTC.h> cannot be used by C++ extensions.

pytorchmergebot pushed a commit that referenced this pull request Jul 30, 2024
Add the Inductor lowering for `torch._scaled_mm`, whose API was last updated in #128683.

The lowering does:
- for tensor-wise scaling, auto-tune between the default ATen kernel (cuBLAS) and Triton kernel configurations.
- for row-wise scaling, auto-tune between the default ATen kernel (CUTLASS kernel added in #125204) and Triton kernel configurations.

The Triton kernel template is based on htyu/FBGEMM@3ad9031 (D56337896) by @choutim, without using SPLIT_K, and that of mm `torch/_inductor/kernel/mm.py`

## Testing:
- Logging shows max-autotune tuning (`AUTOTUNE scaled_mm`) for both tensor-wise and row-wise scaling when called with the two scaling types.
- Row-wise scaling allows operator fusion between preceding pointwise/reduction op and amax/cast:
    - output code Evaluating m=256, n=256, k=256, fusion_case='pointwise', scaling_mode='row'
        - P1477224245 - 2 kernels
    - output code Evaluating m=2048, n=256, k=2048, fusion_case='reduction', scaling_mode='row'
        - P1477227340 - 2 kernels

- UT `python test/inductor/test_fp8.py -- TestFP8Lowering`

## Benchmarking

Eager/compiled tensor-wise/row-wise scaling for various shapes:
https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669
- Some of the “compiled” cases are slightly slower than “eager”. It’s because max-autotune selected the ATen kernel in the compiled case, and I think the discrepancy is variance.

Eager/compiled tensor-wise/row-wise scaling with pointwise/reduction preceding op for various shapes:
https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446

## Questions for reviewers:
- Should the type of the accumulator `ACC_TYPE` always be in float32? If not, where is this type set (output layout?)?

## Todo:
- Make the Triton template use the improved persistent kernel version (pytorch/FBGEMM#2735 by @htyu)

Pull Request resolved: #130422
Approved by: https://github.com/ipiszy

Labels

  • ciflow/binaries: Trigger all binary build and upload jobs on the PR
  • ciflow/trunk: Trigger trunk jobs on your pull request
  • Merged
  • module: floatx (formerly float8): For torch.float8_e5m2 and torch.float8_e4m3 and other sub 8-bit float types
  • Reverted
  • topic: not user facing
