Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145722
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 42 Pending as of commit 6b8eb58 with merge base 0f5a683.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
// forward substitution with loop unrolling and vectorization
#pragma unroll 4
```
This is somewhat annoying to me; for some reason lintrunner removes the indentation here.
malfet left a comment:
Sure, though it would be nice to add some description on perf before/after
Speed improvements over the old kernel (benchmarked on M1 Pro):

For benchmarking one can use the script below; some basic packages like numpy, pandas, and matplotlib are needed. Usage:
```python
import torch
import numpy as np
import time
import csv

matrix_sizes = [512, 1024, 2048, 4096]
batch_sizes = [1, 2, 4, 8, 16]
num_runs = 10
warmup_runs = 3

def create_spd_matrix(n, batch_size):
    torch.manual_seed(42)
    A = torch.randn(batch_size, n, n, dtype=torch.float32)
    return A @ A.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1)

def run_cholesky_mps(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    b = torch.linalg.cholesky(A, upper=False)
    torch.mps.synchronize()
    end = time.perf_counter()
    return b, end - start

results = {
    'N': [],
    'batch_size': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    for batch_size in batch_sizes:
        print(f"\nBenchmarking N={n}, batch_size={batch_size}")
        try:
            A_cpu = create_spd_matrix(n, batch_size)
            A_mps = A_cpu.to("mps")

            for _ in range(warmup_runs):
                _, _ = run_cholesky_mps(A_mps)

            times = []
            for _ in range(num_runs):
                _, t = run_cholesky_mps(A_mps)
                times.append(t)

            mean_time = np.mean(times)
            std_time = np.std(times)

            results['N'].append(n)
            results['batch_size'].append(batch_size)
            results['mean_time'].append(mean_time)
            results['std_time'].append(std_time)

            print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")
        except RuntimeError as e:
            print(f"Error for N={n}, batch_size={batch_size}: {e}")
            continue

with open('cholesky_benchmark_times.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'batch_size', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['batch_size'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])
```

To visualize:
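(The author's visualization script is collapsed here and not reproduced. As a rough sketch only, assuming pandas and matplotlib and the CSV written by the benchmark above, a plot of the results could look like the following; the file names and plot choices are illustrative, not the author's.)

```python
# Hypothetical plotting sketch, not the original collapsed script.
# Reads the CSV produced by the benchmark and plots mean time vs. matrix size,
# one curve per batch size, with standard deviation as error bars.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('cholesky_benchmark_times.csv')

fig, ax = plt.subplots()
for batch_size, group in df.groupby('batch_size'):
    ax.errorbar(group['N'], group['mean_time'], yerr=group['std_time'],
                marker='o', label=f'batch_size={batch_size}')

ax.set_xlabel('Matrix size N')
ax.set_ylabel('Mean time (s)')
ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.legend()
ax.set_title('torch.linalg.cholesky on MPS')
fig.savefig('cholesky_benchmark_times.png')
```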
@pytorchbot merge -f "Lint + MPS are green"
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Follow-up to #145701
Optimizes the SYRK and TRSM kernels of the Cholesky decomposition on MPS. For the SYRK kernel it does matmuls with Apple's simdgroup matrices instead of a tiled implementation, and for the TRSM kernel we do vectorized loads. This PR also puts the command encoder inside the stream queue dispatch (as discussed on the last PR).
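(For context, here is a small PyTorch sketch, written for this summary and not taken from the PR, showing where the TRSM triangular-solve and SYRK rank-k-update steps sit inside a blocked Cholesky factorization; the PR's actual kernels are Metal, and the block size `nb` below is arbitrary.)

```python
# Illustrative sketch only (not code from this PR): a blocked, right-looking
# Cholesky in PyTorch showing the POTRF / TRSM / SYRK structure that the
# MPS kernels implement on the GPU.
import torch

def blocked_cholesky(A, nb=64):
    L = A.clone()
    n = L.shape[-1]
    for k in range(0, n, nb):
        e = min(k + nb, n)
        # POTRF: factorize the current diagonal block
        L[k:e, k:e] = torch.linalg.cholesky(L[k:e, k:e])
        if e < n:
            # TRSM: solve panel * L_kk^T = A_panel for the block column below
            panel = torch.linalg.solve_triangular(
                L[k:e, k:e].mT, L[e:, k:e], upper=True, left=False
            )
            L[e:, k:e] = panel
            # SYRK: symmetric rank-k update of the trailing matrix
            L[e:, e:] -= panel @ panel.mT
    return torch.tril(L)

# Quick check against the library routine (float32 tolerances apply)
A = torch.randn(256, 256)
A = A @ A.T + 256 * torch.eye(256)
print((blocked_cholesky(A) - torch.linalg.cholesky(A)).abs().max())
```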
Script to collect perf
Observed speedups on M1 Pro

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen