
Persistent row-wise kernels #2735

Closed
htyu wants to merge 1 commit into pytorch:main from htyu:export-D58117182

Conversation

@htyu (Contributor) commented Jun 14, 2024

Summary: Enabling persistent kernels for row-wise fp8_fast_accum=True/False

Differential Revision: D58117182

@netlify bot commented Jun 14, 2024

Deploy Preview for pytorch-fbgemm-docs ready!

🔨 Latest commit: 373750b
🔍 Latest deploy log: https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/666cc4ea42405c0008ed7edf
😎 Deploy Preview: https://deploy-preview-2735--pytorch-fbgemm-docs.netlify.app

@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D58117182

htyu added a commit to htyu/FBGEMM that referenced this pull request Jun 14, 2024
Summary:
Pull Request resolved: pytorch#2735

Enabling persistent kernels for row-wise fp8_fast_accum=True/False

Differential Revision: D58117182
@htyu force-pushed the export-D58117182 branch from bbb6cd3 to 6463077 on June 14, 2024 at 07:17

htyu added a commit to htyu/FBGEMM that referenced this pull request Jun 14, 2024
Summary:
Pull Request resolved: pytorch#2735

Enabling persistent kernels for row-wise fp8_fast_accum=True/False

Differential Revision: D58117182
@htyu force-pushed the export-D58117182 branch from 6463077 to 53759f2 on June 14, 2024 at 07:23

htyu added a commit to htyu/FBGEMM that referenced this pull request Jun 14, 2024
Summary:
Pull Request resolved: pytorch#2735

Enabling persistent kernels for row-wise fp8_fast_accum=True/False based on the Triton upstream implementation.

Differential Revision: D58117182
@htyu force-pushed the export-D58117182 branch from 53759f2 to 133a732 on June 14, 2024 at 17:15
Summary:
Pull Request resolved: pytorch#2735

Enabling persistent kernels for row-wise fp8_fast_accum=True/False based on the Triton upstream implementation triton-lang/triton#4099.

Differential Revision: D58117182
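For intuition, the scheduling change can be sketched in plain Python (a toy model, not the actual FBGEMM/Triton code; `tiles_for_program` and `persistent_matmul` are hypothetical names): instead of launching one program per output tile, a persistent kernel launches a fixed number of programs (roughly one per SM), and each program loops over a strided set of tiles, amortizing launch overhead across tiles.

```python
# Toy model of persistent-kernel tile scheduling (not the real Triton kernel):
# a fixed number of "programs" each claim every num_programs-th output tile.

def tiles_for_program(pid, num_programs, m_tiles, n_tiles):
    """Yield the (row, col) tiles that persistent program `pid` processes."""
    total = m_tiles * n_tiles
    for tile_id in range(pid, total, num_programs):
        yield tile_id // n_tiles, tile_id % n_tiles

def persistent_matmul(a, b, num_programs=4, tile=2):
    """C = A @ B, computed tile-by-tile by `num_programs` persistent workers."""
    m, k = len(a), len(a[0])
    n = len(b[0])
    c = [[0.0] * n for _ in range(m)]
    m_tiles = (m + tile - 1) // tile
    n_tiles = (n + tile - 1) // tile
    for pid in range(num_programs):  # each pid models one resident program/SM
        for tm, tn in tiles_for_program(pid, num_programs, m_tiles, n_tiles):
            for i in range(tm * tile, min((tm + 1) * tile, m)):
                for j in range(tn * tile, min((tn + 1) * tile, n)):
                    c[i][j] = sum(a[i][p] * b[p][j] for p in range(k))
    return c
```

On a GPU, the outer `pid` iterations run concurrently and the inner tile loop is the "persistent" part: each program stays resident and pulls the next tile rather than exiting after one.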

@facebook-github-bot (Contributor) commented:

This pull request has been merged in 8a938d6.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 30, 2024
Add the Inductor lowering for `torch._scaled_mm`, whose API was last updated in #128683.

The lowering does:
- for tensor-wise scaling, auto-tune between the default ATen kernel (cuBLAS) and Triton kernel configurations.
- for row-wise scaling, auto-tune between the default ATen kernel (CUTLASS kernel added in #125204) and Triton kernel configurations.
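Conceptually, row-wise scaling keeps one scale per row of A and one per column of B, and the low-precision product is dequantized with the outer product of those scales. A minimal pure-Python sketch of that math (hypothetical helper names; a real kernel rounds to fp8 and fuses the dequantization into the matmul epilogue):

```python
# Sketch of row-wise scaled matmul math (not the real torch._scaled_mm):
# quantize A per row and B per column, then recover the product as
# (Aq @ Bq) * outer(scale_a, scale_b).

FP8_MAX = 448.0  # max finite value of float8_e4m3fn

def rowwise_quantize(rows):
    """Per-row scale so each row's max |value| maps to FP8_MAX."""
    scales, q = [], []
    for row in rows:
        amax = max(abs(v) for v in row) or 1.0
        s = amax / FP8_MAX
        scales.append(s)
        q.append([v / s for v in row])  # a GPU kernel would round to fp8 here
    return q, scales

def scaled_mm(a, b):
    # Column-wise quantization of B == row-wise quantization of B transposed.
    bt = [list(col) for col in zip(*b)]
    aq, sa = rowwise_quantize(a)
    bqt, sb = rowwise_quantize(bt)
    bq = [list(col) for col in zip(*bqt)]
    m, k, n = len(a), len(a[0]), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = sum(aq[i][p] * bq[p][j] for p in range(k))
            out[i][j] = acc * sa[i] * sb[j]  # dequantize: outer product of scales
    return out
```

Because the scale factors depend only on the output row and column, this epilogue is cheap, which is what lets Inductor fuse the preceding amax/cast ops mentioned above.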

The Triton kernel template is based on htyu/FBGEMM@3ad9031 (D56337896) by @choutim (without SPLIT_K) and on the mm template in `torch/_inductor/kernel/mm.py`.

## Testing:
- Logging shows max-autotune tuning (`AUTOTUNE scaled_mm`) for both tensor-wise and row-wise scaling.
- Row-wise scaling allows operator fusion between preceding pointwise/reduction op and amax/cast:
    - output code for m=256, n=256, k=256, fusion_case='pointwise', scaling_mode='row': P1477224245 - 2 kernels
    - output code for m=2048, n=256, k=2048, fusion_case='reduction', scaling_mode='row': P1477227340 - 2 kernels

- UT `python test/inductor/test_fp8.py -- TestFP8Lowering`

## Benchmarking

Eager/compiled tensor-wise/row-wise scaling for various shapes:
https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669
- Some of the “compiled” cases are slightly slower than “eager”. That is because max-autotune selected the ATen kernel in the compiled case, so the remaining discrepancy is likely run-to-run variance.

Eager/compiled tensor-wise/row-wise scaling with pointwise/reduction preceding op for various shapes:
https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446

## Questions for reviewers:
- Should the accumulator type `ACC_TYPE` always be float32? If not, where is that type set (the output layout?)?

## Todo:
- Make the Triton template use the improved persistent kernel version (pytorch/FBGEMM#2735 by @htyu)

Pull Request resolved: #130422
Approved by: https://github.com/ipiszy
