Commit 5bc969d
Update on "[MPS] Implement linear1d as shader"
And get rid of the MPS call, since for some reason the implementation via
the MPSGraph API is more than 100x slower than the Metal shader, at least
according to the following benchmark:
```python
import torch
import time
import subprocess


def benchmark(device, dtype):
    # Create example inputs
    x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype)
    sf = .5

    # Check output against the CPU result
    y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
    z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear")
    outputs_match = torch.allclose(y.cpu(), z)
    if not outputs_match:
        atol = (y.cpu() - z).abs().max()
        rtol = ((y.cpu() - z)[z != 0] / z[z != 0]).abs().max()
        print(f"atol={atol} rtol={rtol}")

    # Measure time manually
    start_time = time.time() * 1000
    for _ in range(1000):
        y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
    # Wait for queued MPS work to finish before stopping the timer
    torch.mps.synchronize()
    end_time = time.time() * 1000
    # Total milliseconds over 1000 calls equals the average microseconds per call
    manual_delta = (end_time - start_time)
    average_time = f"{manual_delta:6.1f}"
    return "True " if outputs_match else "False", average_time


outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        outputs_match, average_time = benchmark(device, dtype)
        outputs_match_list.append(str(outputs_match))
        average_time_list.append(average_time)

# Report the host CPU so results can be attributed to a specific machine
brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device : MPS | CPU")
print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 ")
print(f"Outputs Match : ", " | ".join(outputs_match_list))
print(f"Average Time (us) :", " |".join(average_time_list))
```
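For reference, the operation being timed can be sketched in plain Python. This is not the Metal shader from this change; it is a minimal sketch that assumes the `align_corners=False` behaviour (the default for `interpolate`) and that the source step is `1 / scale_factor`. The helper name `linear1d_reference` is hypothetical.

```python
import math


def linear1d_reference(values, scale_factor):
    """Plain-Python 1D linear interpolation (hypothetical helper).

    Assumes align_corners=False semantics and a source step of 1/scale_factor.
    """
    in_size = len(values)
    out_size = math.floor(in_size * scale_factor)
    scale = 1.0 / scale_factor  # source elements advanced per output element
    out = []
    for dst in range(out_size):
        # Map the centre of the output element back to source coordinates,
        # clamping negative positions to 0
        src = max(scale * (dst + 0.5) - 0.5, 0.0)
        i0 = min(int(src), in_size - 1)   # left neighbour
        i1 = min(i0 + 1, in_size - 1)     # right neighbour, clamped at the edge
        lam = src - i0                    # weight of the right neighbour
        out.append((1.0 - lam) * values[i0] + lam * values[i1])
    return out


print(linear1d_reference([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], 0.5))
```

With `scale_factor=0.5`, as in the benchmark, each output element is the midpoint of a pair of adjacent inputs, so every output depends on only two reads and one multiply-add.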
Benchmark results after the change:
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 2.5 | 2.1 | 2.2 | 161.4 | 115.0 | 161.1
```
And before the change:
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16
Outputs Match : True | True | True | True | True | True
Average Time (us) : 354.0 | 336.0 | 332.4 | 145.5 | 114.7 | 148.3
```
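Taking the reported numbers at face value, the "100x+" figure can be checked with a few lines of arithmetic. The values below are copied from the two result tables; the variable names are illustrative only.

```python
# Average per-call times in microseconds, copied from the result tables
dtypes = ["FP32", "FP16", "BF16"]
mps_before = [354.0, 336.0, 332.4]  # MPSGraph-based implementation
mps_after = [2.5, 2.1, 2.2]         # Metal shader implementation

# Speedup of the shader over the MPSGraph path, per dtype
speedups = {d: before / after for d, before, after in zip(dtypes, mps_before, mps_after)}
for dtype_name, factor in speedups.items():
    print(f"{dtype_name}: {factor:.0f}x faster after the change")
```

Every dtype comes out above 100x on the reported M2 Pro measurements.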
Fixes #144245
[ghstack-poisoned]

1 parent 370a3f8 commit 5bc969d
1 file changed
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| 6560 | 6560 | |
| 6561 | 6561 | |
| 6562 | 6562 | |
| 6563 | | - |
| | 6563 | + |
| 6564 | 6564 | |
| 6565 | 6565 | |
| 6566 | 6566 | |