
[MPS] optimize cholesky #145722

Closed
Isalia20 wants to merge 2 commits into pytorch:main from Isalia20:mps-cholesky-optimization

Conversation

Collaborator

@Isalia20 Isalia20 commented Jan 27, 2025

Follow-up to #145701

Optimizes the SYRK and TRSM kernels of the Cholesky decomposition on MPS. The SYRK kernel now does its matmuls with Apple's simdgroup matrices instead of a tiled implementation, and the TRSM kernel uses vectorized loads. This PR also moves the command encoder inside the stream queue dispatch (as discussed on the last PR).
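For context, these kernels correspond to the steps of the standard blocked (right-looking) Cholesky factorization: factor the diagonal block (POTRF), solve the panel below it (TRSM), then update the trailing matrix (SYRK). A minimal NumPy sketch of that structure, as an illustration only (this is not the Metal code in the PR):

```python
import numpy as np

def blocked_cholesky(A, nb=32):
    """Right-looking blocked Cholesky (lower triangular), illustrating the
    POTRF / TRSM / SYRK structure that the MPS kernels implement."""
    A = A.copy()
    n = A.shape[0]
    for k in range(0, n, nb):
        e = min(k + nb, n)
        # POTRF: factor the diagonal block
        A[k:e, k:e] = np.linalg.cholesky(A[k:e, k:e])
        if e < n:
            L11 = A[k:e, k:e]
            # TRSM: panel solve, A21 <- A21 @ inv(L11).T
            A[e:, k:e] = np.linalg.solve(L11, A[e:, k:e].T).T
            # SYRK: symmetric rank-nb update of the trailing matrix
            A[e:, e:] -= A[e:, k:e] @ A[e:, k:e].T
    return np.tril(A)
```

Here `nb` plays the role of the tile size; the PR's speedups come from how the TRSM and SYRK steps are executed on the GPU, not from this high-level structure.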

Script to collect perf

import torch
import numpy as np
import time
import csv

matrix_sizes = [512, 1024, 2048, 4096]
batch_sizes = [1, 2, 4, 8, 16]
num_runs = 10
warmup_runs = 3

def create_spd_matrix(n, batch_size):
    torch.manual_seed(42)
    A = torch.randn(batch_size, n, n, dtype=torch.float32)
    return A @ A.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1)

def run_cholesky_mps(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    b = torch.linalg.cholesky(A, upper=False)
    torch.mps.synchronize()
    end = time.perf_counter()
    return b, end - start

results = {
    'N': [],
    'batch_size': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    for batch_size in batch_sizes:
        print(f"\nBenchmarking N={n}, batch_size={batch_size}")
        
        try:
            A_cpu = create_spd_matrix(n, batch_size)
            A_mps = A_cpu.to("mps")
            
            for _ in range(warmup_runs):
                _, _ = run_cholesky_mps(A_mps)
            
            times = []
            for _ in range(num_runs):
                _, t = run_cholesky_mps(A_mps)
                times.append(t)
            
            mean_time = np.mean(times)
            std_time = np.std(times)
            
            results['N'].append(n)
            results['batch_size'].append(batch_size)
            results['mean_time'].append(mean_time)
            results['std_time'].append(std_time)
            
            print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")
            
        except RuntimeError as e:
            print(f"Error for N={n}, batch_size={batch_size}: {e}")
            continue

with open('cholesky_benchmark_times.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'batch_size', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['batch_size'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])
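A note on create_spd_matrix above: A @ A.T is positive semidefinite, so adding n * I pushes every eigenvalue up to at least n, guaranteeing a well-conditioned symmetric positive definite input. A quick standalone check (a NumPy stand-in for the torch helper, shapes simplified):

```python
import numpy as np

# Mirrors create_spd_matrix above, in NumPy for a standalone check
rng = np.random.default_rng(42)
n = 64
A = rng.standard_normal((n, n))
spd = A @ A.T + n * np.eye(n)

assert np.allclose(spd, spd.T)                    # symmetric
assert np.linalg.eigvalsh(spd).min() >= n - 1e-6  # eigenvalues shifted up by n
```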

Observed speedups on M1 Pro
[cholesky_speedup plot: speedup of the new kernel over the old, across matrix sizes and batch sizes]

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen


pytorch-bot bot commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145722

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 42 Pending

As of commit 6b8eb58 with merge base 0f5a683:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: mps (Release notes category) label Jan 27, 2025
Comment on lines +205 to +206
// forward substitution with loop unrolling and vectorization
#pragma unroll 4
Collaborator Author


This is somewhat annoying to me; for some reason lintrunner removes the indentation here

@janeyx99 janeyx99 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Jan 29, 2025
Contributor

@malfet malfet left a comment


Sure, though it would be nice to add some description on perf before/after

@malfet malfet added the ciflow/mps (Run MPS tests (subset of trunk)) and module: mps (Related to Apple Metal Performance Shaders framework) labels Jan 31, 2025
Collaborator Author

Isalia20 commented Jan 31, 2025

Speed improvements over the old kernel (benchmarked on M1 Pro):

[cholesky_speedup plot]

For benchmarking, one can use the script below (the same one as in the PR description); some basic packages such as numpy, pandas, and matplotlib are needed. Usage:

  1. Compile with old kernel
  2. Run the below script
  3. Rename saved csv to cholesky_benchmark_times_old.csv
  4. Compile with new kernel
  5. Run the below script
  6. Rename saved csv to cholesky_benchmark_times_new.csv
  7. Run the visualization script below
(The benchmark script here is identical to the one in the PR description above.)

To visualize:

import pandas as pd
import matplotlib.pyplot as plt

old_data = pd.read_csv("cholesky_benchmark_times_old.csv")
new_data = pd.read_csv("cholesky_benchmark_times_new.csv")
merged_data = pd.merge(old_data, new_data, on=["N", "batch_size"], suffixes=("_old", "_new"))
merged_data["speedup"] = merged_data["mean_time_old"] / merged_data["mean_time_new"]
pivot_table = merged_data.pivot(index="batch_size", columns="N", values="speedup")
plt.figure(figsize=(10, 6))
for N in pivot_table.columns:
    plt.plot(pivot_table.index, pivot_table[N], marker="o", label=f"N={N}")
plt.xlabel("Batch Size")
plt.ylabel("Speedup (Old Time / New Time)")
plt.title("Speedup Comparison: Old vs New Times")
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.legend(title="Matrix Size (N)")
plt.tight_layout()
plt.savefig("cholesky_speedup.png")
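As a sanity check of the merge/pivot logic in the visualization script, a toy example with synthetic frames (the numbers are made up):

```python
import pandas as pd

# Two tiny stand-ins for the old/new CSVs (values are hypothetical)
old = pd.DataFrame({"N": [512, 512], "batch_size": [1, 2],
                    "mean_time": [0.2, 0.4], "std_time": [0.0, 0.0]})
new = pd.DataFrame({"N": [512, 512], "batch_size": [1, 2],
                    "mean_time": [0.1, 0.1], "std_time": [0.0, 0.0]})
merged = pd.merge(old, new, on=["N", "batch_size"], suffixes=("_old", "_new"))
merged["speedup"] = merged["mean_time_old"] / merged["mean_time_new"]
pivot = merged.pivot(index="batch_size", columns="N", values="speedup")
print(pivot)  # speedup is 2.0 for batch_size=1 and 4.0 for batch_size=2
```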

Contributor

malfet commented Jan 31, 2025

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f only as a last resort; consider -i/--ignore-current instead to merge while ignoring current failures. That allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here


Labels

ciflow/mps (Run MPS tests (subset of trunk))
Merged
module: mps (Related to Apple Metal Performance Shaders framework)
open source
release notes: mps (Release notes category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

5 participants