
Cholesky mps implementation #144193

Closed
Isalia20 wants to merge 35 commits into pytorch:main from Isalia20:cholesky-mps

Conversation

@Isalia20
Collaborator

@Isalia20 Isalia20 commented Jan 4, 2025

Requested in #77764

The PR is still in draft because it needs some cleanups and optimizations to at least reach CPU performance. Tasks:

  • Make upper=True work; only upper=False works now
  • Code cleanup
  • Optimizations (might need some help on this; tried my best, maybe there is still some more to squeeze out)
  • Checks for positive definite input
  • Support for (*, N, N) input; currently only supports (B, N, N) input
  • Support other dtypes (float16, bfloat16)
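
As a plain-NumPy illustration of two of the tasks above (a sketch only, not the PR's Metal kernels): the upper=True factor can be derived from the lower one, and positive definiteness is detected by the factorization itself hitting a non-positive pivot:

```python
import numpy as np

# Sketch only: numpy stands in for the MPS kernels here.
# If A = L @ L^H (lower factor), then A = U^H @ U with U = L^H.
def cholesky_upper(A):
    return np.linalg.cholesky(A).conj().T

# Positive definiteness: the factorization fails on a non-positive
# pivot, which numpy surfaces as LinAlgError.
def is_positive_definite(A):
    try:
        np.linalg.cholesky(A)
        return True
    except np.linalg.LinAlgError:
        return False

rng = np.random.default_rng(0)
M = rng.standard_normal((4, 4))
A = M @ M.T + 4 * np.eye(4)   # symmetric positive definite

U = cholesky_upper(A)
assert np.allclose(U.conj().T @ U, A)
assert np.allclose(U, np.triu(U))   # U is upper-triangular
assert is_positive_definite(A) and not is_positive_definite(-A)
```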

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Jan 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144193

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 26 Unrelated Failures

As of commit dc7617d with merge base b7bef1c:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: mps Release notes category label Jan 4, 2025
@github-actions
Contributor

github-actions bot commented Jan 4, 2025

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@Isalia20
Collaborator Author

Isalia20 commented Jan 5, 2025

The PR is not ready to merge yet, but I'll mark it as ready for review to get some comments from you before proceeding further. There are still some things to optimize, like launching a kernel to deal with syrk and trsm in parallel instead of in loops, which I will tackle in the following days.
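
For context, the syrk/trsm steps mentioned above come from the classic right-looking blocked factorization, which can be sketched in plain NumPy (an illustration, not the PR's Metal code; the block size and helper names are made up here):

```python
import numpy as np

def blocked_cholesky(A, nb=32):
    """Right-looking blocked Cholesky sketch. Per block column:
    factor the diagonal block (potrf), triangular-solve the panel
    below it (trsm), then rank-k update the trailing submatrix
    (syrk). The trsm and syrk steps are the ones that could be
    launched in parallel instead of looped."""
    A = A.copy()
    n = A.shape[0]
    for k in range(0, n, nb):
        e = min(k + nb, n)
        # potrf: factor the diagonal block in place
        A[k:e, k:e] = np.linalg.cholesky(A[k:e, k:e])
        if e < n:
            L11 = A[k:e, k:e]
            # trsm: L21 = A21 @ inv(L11).T, done here via a general solve
            A[e:, k:e] = np.linalg.solve(L11, A[e:, k:e].T).T
            # syrk: trailing update A22 -= L21 @ L21.T
            A[e:, e:] -= A[e:, k:e] @ A[e:, k:e].T
    return np.tril(A)

rng = np.random.default_rng(0)
M = rng.standard_normal((6, 6))
A = M @ M.T + 6 * np.eye(6)   # symmetric positive definite
assert np.allclose(blocked_cholesky(A, nb=2), np.linalg.cholesky(A))
```

Within one block column the trsm depends on the potrf and the syrk depends on the trsm, so the parallelism is across row panels inside each trsm/syrk step, not across the steps themselves.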

@Isalia20 Isalia20 marked this pull request as ready for review January 5, 2025 09:05
@Isalia20
Collaborator Author

Isalia20 commented Jan 5, 2025

Added more optimizations, but still slower than CPU 😞 Comments about improving speed are very welcome.

@Isalia20
Collaborator Author

Isalia20 commented Jan 5, 2025

I compare implementations like this:

import torch
import numpy as np
from torch.profiler import profile, record_function, ProfilerActivity
from torch.profiler import schedule

n = 512
batch_size = 64
num_runs = 20

torch.manual_seed(42)
A_cpu = torch.randn(batch_size, n, n, dtype=torch.float32)
A_cpu = A_cpu @ A_cpu.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1)

A_mps = A_cpu.to("mps")

def run_cholesky_cpu(A):
    with record_function("cholesky_cpu"):
        b = torch.linalg.cholesky(A, upper=False)
        torch.cpu.synchronize()
        return b

def run_cholesky_mps(A):
    with record_function("cholesky_mps"):
        b = torch.linalg.cholesky(A, upper=False)
        torch.mps.synchronize()
        return b

my_schedule = schedule(
    skip_first=5,
    wait=0,
    warmup=3,
    active=1,
    repeat=1
)

# Combined warmup pass (the profiler results here are not used)
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=my_schedule,
    record_shapes=True
) as prof:
    for _ in range(num_runs):
        run_cholesky_mps(A_mps)
        run_cholesky_cpu(A_cpu)

# Profile CPU
with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    with_stack=True
) as prof_cpu:
    for _ in range(num_runs):
        L_cpu = run_cholesky_cpu(A_cpu)

print("\nCPU Profile:")
print(prof_cpu.key_averages().table(sort_by="cpu_time_total"))

# Warmup run for MPS
_ = run_cholesky_mps(A_mps)

# Profile MPS
with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    with_stack=True
) as prof_mps:
    for _ in range(num_runs):
        L_mps = run_cholesky_mps(A_mps)

print("\nMPS Profile:")
print(prof_mps.key_averages().table(sort_by="cpu_time_total"))

L_mps_cpu = L_mps.cpu()

# check
tolerance = 1e-5
assert torch.allclose(L_cpu.cpu(), L_mps_cpu, rtol=tolerance, atol=tolerance), \
    f"Maximum difference between CPU and MPS results: {torch.max(torch.abs(L_cpu - L_mps_cpu))}"

mps_events = [e for e in prof_mps.key_averages() if "cholesky_mps" in e.key]
cpu_events = [e for e in prof_cpu.key_averages() if "cholesky_cpu" in e.key]

mps_mean = np.mean([e.cpu_time * 1e-6 for e in mps_events])  # Convert to seconds
cpu_mean = np.mean([e.cpu_time * 1e-6 for e in cpu_events])

print("\nAll assertions passed! CPU and MPS results match within tolerance.")
print(f"\nCPU Cholesky mean time: {cpu_mean:.4f} seconds")
print(f"MPS Cholesky mean time: {mps_mean:.4f} seconds")
print(f"\nSpeed comparison: MPS is {cpu_mean/mps_mean:.2f}x faster than CPU on average")
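
For quick comparisons outside the profiler, a minimal perf_counter harness also works. NumPy's CPU factorization stands in below since MPS hardware may not be available; the bench helper and sizes are illustrative, not from the PR:

```python
import time
import numpy as np

def bench(fn, warmup=3, runs=10):
    """Mean wall-clock seconds per call, after warmup iterations.
    For an MPS callable, fn should end with torch.mps.synchronize()
    so queued GPU work is included in the measurement."""
    for _ in range(warmup):
        fn()
    t0 = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - t0) / runs

rng = np.random.default_rng(0)
M = rng.standard_normal((8, 64, 64))
A = M @ M.transpose(0, 2, 1) + 64 * np.eye(64)  # batched SPD input

mean_s = bench(lambda: np.linalg.cholesky(A))
print(f"mean: {mean_s * 1e3:.3f} ms per call")
```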

@cpuhrsch cpuhrsch requested a review from albanD January 7, 2025 06:36
@cpuhrsch cpuhrsch added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 7, 2025
@pytorchmergebot
Collaborator

@Isalia20 your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Jan 16, 2025
@pytorch-bot pytorch-bot bot dismissed stale reviews from malfet January 16, 2025 21:37

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@Isalia20
Collaborator Author

Isalia20 commented Jan 16, 2025

Removed the op from the fallback (it was there originally because, before annotating linalg cholesky with CompositeImplicitAutograd, I listed devices like CUDA, CPU, and Meta, and that's why it had to be in the fallback).

Tests should pass now, but I suggest running the workflow here before merging. @malfet

@Isalia20
Collaborator Author

Isalia20 commented Jan 17, 2025

Hmm, still failing. This is interesting. I'm not really sure why it's failing, though. I looked into the test but don't see any cholesky ops in it. Any suggestions?

@Isalia20
Collaborator Author

Some of the PRs still had the issue with levit after reverting this, so we can say those issues weren't caused by this PR. Maybe we can merge? @malfet

@Isalia20
Collaborator Author

bump

@Isalia20
Collaborator Author

@malfet Can we merge this? Want to submit another PR with speeding this up a lot

@malfet
Contributor

malfet commented Jan 26, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 26, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@malfet
Contributor

malfet commented Jan 26, 2025

@pytorchbot merge -r

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/144193/head returned non-zero exit code 1

Rebasing (1/33)
...
Rebasing (26/33)
Auto-merging aten/src/ATen/native/native_functions.yaml
Auto-merging tools/autograd/derivatives.yaml
Auto-merging torch/_inductor/lowering.py
CONFLICT (content): Merge conflict in torch/_inductor/lowering.py
error: could not apply 372da90d281... fix failing tests
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply 372da90d281... fix failing tests

Raised by https://github.com/pytorch/pytorch/actions/runs/12971799472

@malfet
Contributor

malfet commented Jan 26, 2025

@Isalia20 do you mind rebasing this PR? As of right now I could not get a clear signal. (Though I can always try merging it and see what will happen...)

@Isalia20
Collaborator Author

I tried rebasing, but there are a lot of commits, so I decided to create a new one checked out from main:
#145701

@Isalia20 Isalia20 closed this Jan 26, 2025
pytorchmergebot pushed a commit that referenced this pull request Jan 27, 2025
Requested in #77764

Closed #144193  due to a lot of conflicts when rebasing
Pull Request resolved: #145701
Approved by: https://github.com/malfet
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this pull request Jan 27, 2025
Requested in pytorch#77764

Closed pytorch#144193  due to a lot of conflicts when rebasing
Pull Request resolved: pytorch#145701
Approved by: https://github.com/malfet

Labels

  • ci-no-td: Do not run TD on this PR
  • ciflow/inductor
  • ciflow/mps: Run MPS tests (subset of trunk)
  • ciflow/trunk: Trigger trunk jobs on your pull request
  • Merged
  • module: inductor
  • open source
  • release notes: mps: Release notes category
  • Reverted
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants