
use cooperative schedule in scaled_mm for fast_accum=false #144809

Closed
ngimel wants to merge 1 commit into main from ngimel/scaled_mm_coop


@ngimel ngimel commented Jan 14, 2025

Collaborator

This improves performance for large matrices by more than 2x; a more detailed benchmark is coming.
On master:
[image]
On this branch:
[image]
A plot similar to pytorch/ao#1325 (comment)

Benchmarking code:
import torch
from triton.testing import do_bench
import itertools

def fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=False):
    return torch._scaled_mm(a, b.t(), scale_a.view(-1, 1), scale_b.view(1, -1), use_fast_accum=use_fast_accum, out_dtype=torch.bfloat16)

def fn_aten(a, b, scale, use_fast_accum=False):
    return torch._scaled_mm(a, b.t(), scale, scale, use_fast_accum=use_fast_accum, out_dtype=torch.bfloat16)

for i, j, p in itertools.product(range(9, 15), range(9, 15), range(9, 15)):
    # Sweep every combination of power-of-two sizes from 512 to 16384.
    m = 2**i
    n = 2**j
    k = 2**p

    a=torch.randn(m, k, device="cuda").to(dtype=torch.float8_e4m3fn)
    b=torch.randn(n, k, device="cuda").to(dtype=torch.float8_e4m3fn)
    scale_a = torch.randint(1, 11, (a.shape[0],), device="cuda", dtype=torch.float32)
    scale_b = torch.randint(1, 11, (b.shape[0],), device="cuda", dtype=torch.float32)
    scale_0 = torch.randn((), device="cuda", dtype=torch.float32)

    ms_rowwise_fast = do_bench(lambda: fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=True), warmup=25, rep=50)
    ms_rowwise_slow = do_bench(lambda: fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=False), warmup=25, rep=50)

    ms_tensor_fast = do_bench(lambda: fn_aten(a, b, scale_0, use_fast_accum=True), warmup=25, rep=50)
    ms_tensor_slow = do_bench(lambda: fn_aten(a, b, scale_0, use_fast_accum=False), warmup=25, rep=50)

    print(f"m={m}, n={n}, k={k}, fast={ms_rowwise_fast}, slow={ms_rowwise_slow}, ratio_tw={ms_tensor_slow /ms_tensor_fast}, ratio_rw={ms_rowwise_slow / ms_rowwise_fast}")
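Raw milliseconds are hard to compare across shapes, so one option is to convert each timing into achieved throughput. The helper below is a minimal sketch, not part of this PR: `matmul_tflops` is a hypothetical name, and it assumes the timing passed in is in milliseconds (which is what `do_bench` reports) and counts the usual 2*m*n*k floating-point operations for a matmul.

```python
def matmul_tflops(m: int, n: int, k: int, ms: float) -> float:
    """Achieved TFLOP/s for an (m x k) @ (k x n) matmul that took `ms` milliseconds.

    Hypothetical helper for reading the benchmark output; not part of the PR.
    """
    flops = 2 * m * n * k          # one multiply and one add per inner-product term
    seconds = ms * 1e-3            # assumes the timing is given in milliseconds
    return flops / seconds / 1e12  # FLOP/s scaled to TFLOP/s
```

Inside the loop it could be called as `matmul_tflops(m, n, k, ms_rowwise_slow)` and printed next to the ratios.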

Higher N/K values still have about a 40% penalty; some additional heuristic tweaks might be useful.
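To find which shapes carry that kind of penalty, the slow/fast ratios printed by the benchmark can be turned into a percentage and filtered. This is a rough sketch under assumptions: the function names, the `(m, n, k, ms_fast, ms_slow)` tuple layout, and the 30% default threshold are all hypothetical and chosen only for illustration.

```python
def slowdown_percent(ms_slow: float, ms_fast: float) -> float:
    """Extra time of the slow path relative to the fast path, in percent."""
    return (ms_slow / ms_fast - 1.0) * 100.0


def shapes_over_threshold(results, threshold_percent=30.0):
    """Return (m, n, k, penalty) for rows whose penalty exceeds the threshold.

    `results` is a hypothetical list of (m, n, k, ms_fast, ms_slow) tuples
    collected from the benchmark loop; it is not produced by the PR itself.
    """
    flagged = []
    for m, n, k, ms_fast, ms_slow in results:
        penalty = slowdown_percent(ms_slow, ms_fast)
        if penalty > threshold_percent:
            flagged.append((m, n, k, penalty))
    return flagged
```

A slow/fast ratio of 1.4 corresponds to a 40% penalty under this definition, so the flagged list would point at the shapes where heuristic tweaks are most likely to pay off.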

@ngimel ngimel requested review from eqy and syed-ahmed as code owners January 14, 2025 23:12

pytorch-bot Bot commented Jan 14, 2025


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144809

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6e61be7 with merge base 64bcf39:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added the release notes: cuda release notes category label Jan 14, 2025
@ngimel ngimel requested review from drisspg and lw January 14, 2025 23:13

ngimel commented Jan 15, 2025

Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 15, 2025
@pytorchmergebot

Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here


Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
release notes: cuda (release notes category)


3 participants