[MPS][BE] Do not create 4 instances of FUSED_ADAM_OPS#141090
[MPS][BE] Do not create 4 instances of FUSED_ADAM_OPS#141090malfet wants to merge 2 commits intogh/malfet/60/basefrom
FUSED_ADAM_OPS#141090Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141090
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 New Failure, 2 Unrelated FailuresAs of commit 34e00fc with merge base 0443398 ( NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot merge -f "Mac builds + lint looks fine" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Pull Request resolved: #141092 Approved by: https://github.com/Skylion007, https://github.com/kulinseth ghstack dependencies: #141089, #141090
Instead of calling `REGISTER_FUSED_ADAM_OP` macro with 7 parameters 16 times, 4 type parameter macros for each op and then one op to define the quartet of ops: Adam, AdamW and their grad functions Pull Request resolved: #141103 Approved by: https://github.com/kulinseth ghstack dependencies: #141089, #141090, #141092
For MacOS14+
Running following script
```python
```
Produces following results on M4Pro running MacOS 15
```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
| Fused: True | Fused: False
1 threads: ----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10 | 283 | 2810
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10 | 277 | 2430
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10 | 285 | 2400
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10 | 278 | 2250
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10 | 504 | 2700
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10 | 478 | 2600
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10 | 506 | 2500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10 | 482 | 2300
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10 | 2089 | 4190
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10 | 1940 | 3800
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10 | 2100 | 3770
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10 | 1950 | 3600
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50 | 842 | 14000
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50 | 835 | 11800
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50 | 845 | 11700
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50 | 855 | 11000
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50 | 1410 | 14000
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50 | 1350 | 12000
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50 | 1400 | 12000
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50 | 1340 | 11000
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50 | 9767 | 20400
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50 | 8991 | 18600
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50 | 9803 | 18300
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50 | 9070 | 17600
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 27000
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 24100
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 23500
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 21800
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 2740 | 26000
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 2580 | 24000
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 2730 | 25000
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 2600 | 23000
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 19350 | 39000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 17780 | 37300
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 19400 | 37000
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 17900 | 35500
Times are in microseconds (us).
```
Pull Request resolved: #141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
For MacOS14+ Running following script (adapted from one mentioned in #127242 ) ```python import torch from torch.optim import adam, adamw import torch.utils.benchmark as benchmark import itertools def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused): fn( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach=False, capturable=False, fused=fused, amsgrad=amsgrad, beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=.0, eps=1e-5, maximize=False, grad_scale=None, found_inf=None, ) torch.mps.synchronize() device, dtype = "mps", torch.bfloat16 results = [] for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]): print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}") params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)] max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else [] state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)] fn = adamw.adamw if adamWflag else adam.adam for fused in [True, False]: t = benchmark.Timer( stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)', label=f'Fused Adam on {device} using {dtype}', sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}", globals=locals(), description= f"Fused: {fused}", ).blocked_autorange(min_run_time=5) results.append(t) compare = benchmark.Compare(results) compare.trim_significant_figures() compare.colorize(rowwise=True) compare.print() ``` Produces following results on M4Pro running MacOS 15 ``` [-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------] | Fused: True | Fused: False 1 threads: ---------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10 | 283 | 2810 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10 | 277 | 2430 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10 | 285 | 2400 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10 | 278 | 2250 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10 | 504 | 2700 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10 | 478 | 2600 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10 | 506 | 2500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10 | 482 | 2300 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10 | 2089 | 4190 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10 | 1940 | 3800 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10 | 2100 | 3770 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10 | 1950 | 3600 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50 | 842 | 14000 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50 | 835 | 11800 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50 | 845 | 11700 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50 | 855 | 11000 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50 | 1410 | 14000 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50 | 1350 | 12000 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50 | 1400 | 12000 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50 | 1340 | 11000 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50 | 9767 | 20400 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50 | 8991 | 18600 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50 | 9803 | 18300 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50 | 9070 | 17600 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 27000 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 24100 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 23500 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 21800 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 2740 | 26000 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 2580 | 24000 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 2730 | 25000 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 2600 | 23000 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 19350 | 39000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 17780 | 37300 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 19400 | 37000 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 17900 | 35500 Times are in microseconds (us). ``` Pull Request resolved: #141104 Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007 ghstack dependencies: #141089, #141090, #141092, #141103
Defining `static char shaderSource[]` in the header will instantiate it as often as it is included. Solved the problem by renaming `static auto getCPLState(const std::string&)` into `auto getFusedAdamCPLState(const std::string&)` and instantiating it only once resulted in 500K reduction in binary size (and perhaps even more in runtime footprint) I.e. before ``` % ls -lak lib/libtorch_cpu.dylib -rwxr-xr-x 1 malfet staff 183357744 Nov 19 17:58 lib/libtorch_cpu.dylib ``` and afer ``` % ls -lak lib/libtorch_cpu.dylib -rwxr-xr-x 1 malfet staff 183357120 Nov 19 17:57 lib/libtorch_cpu.dylib ``` Pull Request resolved: pytorch#141090 Approved by: https://github.com/Skylion007 ghstack dependencies: pytorch#141089
Pull Request resolved: pytorch#141092 Approved by: https://github.com/Skylion007, https://github.com/kulinseth ghstack dependencies: pytorch#141089, pytorch#141090
Instead of calling `REGISTER_FUSED_ADAM_OP` macro with 7 parameters 16 times, 4 type parameter macros for each op and then one op to define the quartet of ops: Adam, AdamW and their grad functions Pull Request resolved: pytorch#141103 Approved by: https://github.com/kulinseth ghstack dependencies: pytorch#141089, pytorch#141090, pytorch#141092
For MacOS14+
Running following script
```python
```
Produces following results on M4Pro running MacOS 15
```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
| Fused: True | Fused: False
1 threads: ----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10 | 283 | 2810
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10 | 277 | 2430
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10 | 285 | 2400
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10 | 278 | 2250
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10 | 504 | 2700
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10 | 478 | 2600
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10 | 506 | 2500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10 | 482 | 2300
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10 | 2089 | 4190
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10 | 1940 | 3800
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10 | 2100 | 3770
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10 | 1950 | 3600
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50 | 842 | 14000
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50 | 835 | 11800
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50 | 845 | 11700
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50 | 855 | 11000
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50 | 1410 | 14000
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50 | 1350 | 12000
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50 | 1400 | 12000
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50 | 1340 | 11000
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50 | 9767 | 20400
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50 | 8991 | 18600
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50 | 9803 | 18300
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50 | 9070 | 17600
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 27000
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 24100
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 23500
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 21800
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 2740 | 26000
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 2580 | 24000
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 2730 | 25000
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 2600 | 23000
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 19350 | 39000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 17780 | 37300
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 19400 | 37000
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 17900 | 35500
Times are in microseconds (us).
```
Pull Request resolved: pytorch#141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: pytorch#141089, pytorch#141090, pytorch#141092, pytorch#141103
For MacOS14+ Running following script (adapted from one mentioned in pytorch#127242 ) ```python import torch from torch.optim import adam, adamw import torch.utils.benchmark as benchmark import itertools def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused): fn( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach=False, capturable=False, fused=fused, amsgrad=amsgrad, beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=.0, eps=1e-5, maximize=False, grad_scale=None, found_inf=None, ) torch.mps.synchronize() device, dtype = "mps", torch.bfloat16 results = [] for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]): print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}") params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)] max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else [] state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)] fn = adamw.adamw if adamWflag else adam.adam for fused in [True, False]: t = benchmark.Timer( stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)', label=f'Fused Adam on {device} using {dtype}', sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}", globals=locals(), description= f"Fused: {fused}", ).blocked_autorange(min_run_time=5) results.append(t) compare = benchmark.Compare(results) compare.trim_significant_figures() compare.colorize(rowwise=True) compare.print() ``` Produces following results on M4Pro running MacOS 15 ``` [-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------] | Fused: True | Fused: False 1 threads: ---------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10 | 283 | 2810 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10 | 277 | 2430 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10 | 285 | 2400 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10 | 278 | 2250 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10 | 504 | 2700 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10 | 478 | 2600 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10 | 506 | 2500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10 | 482 | 2300 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10 | 2089 | 4190 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10 | 1940 | 3800 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10 | 2100 | 3770 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10 | 1950 | 3600 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50 | 842 | 14000 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50 | 835 | 11800 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50 | 845 | 11700 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50 | 855 | 11000 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50 | 1410 | 14000 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50 | 1350 | 12000 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50 | 1400 | 12000 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50 | 1340 | 11000 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50 | 9767 | 20400 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50 | 8991 | 18600 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50 | 9803 | 18300 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50 | 9070 | 17600 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 27000 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 24100 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 23500 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 21800 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 2740 | 26000 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 2580 | 24000 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 2730 | 25000 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 2600 | 23000 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 19350 | 39000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 17780 | 37300 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 19400 | 37000 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 17900 | 35500 Times are in microseconds (us). ``` Pull Request resolved: pytorch#141104 Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007 ghstack dependencies: pytorch#141089, pytorch#141090, pytorch#141092, pytorch#141103
Defining `static char shaderSource[]` in the header will instantiate it as often as it is included Solved the problem by renaming `static auto getCPLState(const std::string&)` into `auto getFusedAdamCPLState(const std::string&)` and instantiating it only once resulted in 500K reduction in binary size (and perhaps even more in runtime footprint) I.e. before ``` % ls -lak lib/libtorch_cpu.dylib -rwxr-xr-x 1 malfet staff 183357744 Nov 19 17:58 lib/libtorch_cpu.dylib ``` and afer ``` % ls -lak lib/libtorch_cpu.dylib -rwxr-xr-x 1 malfet staff 183357120 Nov 19 17:57 lib/libtorch_cpu.dylib ``` ghstack-source-id: 4f4a97c Pull Request resolved: pytorch/pytorch#141090
Stack from ghstack (oldest at bottom):
FUSED_ADAM_OPS#141090Defining
static char shaderSource[]in the header will instantiate it as often as it is included.Solved the problem by renaming
static auto getCPLState(const std::string&)intoauto getFusedAdamCPLState(const std::string&)and instantiating it only once resulted in 500K reduction in binary size (and perhaps even more in runtime footprint)I.e. before
and afer