Skip to content

Commit 5bc969d

Browse files
committed
Update on "[MPS] Implement linear1d as shader"
And get rid of MPS call, as for some reason implementation via MPSGraph API call is 100x+ slower than the Metal shader, at least according to the following benchmark ```python import torch import time import subprocess def benchmark(device, dtype): # Create example inputs x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype) sf = .5 # Check output y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear") z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear") outputs_match = torch.allclose(y.cpu(), z) if not outputs_match: atol = (y.cpu() - z).abs().max() rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max() print(f"atol={atol} rtol={rtol}") # Measure time manually start_time = time.time() * 1000 for _ in range(1000): y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear") torch.mps.synchronize() end_time = time.time() * 1000 manual_delta = (end_time - start_time) average_time = f"{manual_delta:6.1f}" return "True " if outputs_match else "False", average_time outputs_match_list = [] average_time_list = [] for device in ["mps", "cpu"]: for dtype in [torch.float32, torch.float16, torch.bfloat16]: outputs_match, average_time = benchmark(device, dtype) outputs_match_list.append(str(outputs_match)) average_time_list.append(average_time) brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip() print(f"\nBenchmarking Results (collected on {brand_string}):") print("-"*40) print("Device : MPS | CPU") print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 ") print(f"Outputs Match : ", " | ".join(outputs_match_list)) print(f"Average Time (us) :", " |".join(average_time_list)) ``` Benchmark results after the change ``` Benchmarking Results (collected on Apple M2 Pro): ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 2.5 | 2.1 | 2.2 | 
161.4 | 115.0 | 161.1 ``` And before the change ``` Benchmarking Results (collected on Apple M2 Pro): ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 354.0 | 336.0 | 332.4 | 145.5 | 114.7 | 148.3 ``` Fixes #144245 [ghstack-poisoned]
1 parent 370a3f8 commit 5bc969d

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

test/test_mps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6560,7 +6560,7 @@ def helper(shape, output_size, scales, mode, align_corners=False):
65606560
helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
65616561
helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
65626562
# Regression test for https://github.com/pytorch/pytorch/issues/144245
6563-
inp = torch.tensor([[[1.]],[[2]],[[4]]], device='mps')
6563+
inp = torch.tensor([[[1.]], [[2]], [[4]]], device='mps')
65646564
for align_corners in [True, False]:
65656565
def interp(x):
65666566
return F.interpolate(x, 3, mode='linear', align_corners=align_corners)

0 commit comments

Comments
 (0)