Convert MPS Tensor data using MPSGraph API#78092
Convert MPS Tensor data using MPSGraph API#78092lhoenig wants to merge 11 commits intopytorch:masterfrom
Conversation
When doing an MPS to MPS copy and source and destination have different dtypes, build and run an MPSGraph with one call to [MPSGraph castTensor:toType:name] to actually convert tensor data
🔗 Helpful links
❌ 5 New FailuresAs of commit d99deb1 (more details on the Dr. CI page): Expand to see more
🕵️ 5 new failures recognized by patternsThe following CI failures do not appear to be due to upstream breakages
|
|
Here is a small program I used to test the approach: https://gist.github.com/lhoenig/bb0636548b2abbe690e81a3fad139ee3 |
|
@kulinseth does this work better than your WIP solution? Also if we go with this, we should add tests. |
|
Just added commits that also fix dtype conversion in This way the offsets in the blit copy calls can still be used in both transfer copy functions. Before: In [4]: pt.tensor(1.3, device="mps").to("cpu", pt.int)
Out[4]: tensor(1067869824, dtype=torch.int32)
In [5]: pt.tensor(1.3, device="cpu").to("mps", pt.int)
Out[5]: tensor(1067869798, device='mps:0', dtype=torch.int32)After: In [4]: pt.tensor(1.3, device="mps").to("cpu", pt.int)
Out[4]: tensor(1, dtype=torch.int32)
In [5]: pt.tensor(1.3, device="cpu").to("mps", pt.int)
Out[5]: tensor(1, device='mps:0', dtype=torch.int32) |
Thanks @lhoenig , the fix makes sense. I am testing that, right now. The casting is indeed required if the types are different when copying over. I will have some comments after i finish testing. Will post here. |
Thanks @albanD . Totally agree, please add tests in test_mps.py file. The ones you have in the description will be good. |
This is a good change, lets go with this. @lhoenig , please add the requested changes and we should be good. |
Awesome, thanks for the fast review, will get to it quickly! |
|
@kulinseth Ok, done for now, have also added a test in |
Its failing with: |
|
@lhoenig , I have some fixes, can i upload as a commit to this PR ? |
Sure! |
|
@kulinseth I too uploaded fixes now - hope thats ok! In this codepath, |
Awesome, these were the fixes i had locally. Now all the tests are passing. |
|
FWIW: I also tested this PR and now our transformer models work (with one additional fix). Thanks a bunch! |
|
The issues in the failure are not related to this PR: |
|
@pytorchbot merge this please. |
|
Merge failed due to Matched rule superuser, but it was not reviewed yet by any of:lazysjb,HDCharles,bzinodev,huangyi1979,protonu, ... |
|
@pytorchbot merge this |
|
Hey @lhoenig. |
Summary: Fixes #78091 If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~. Before: ```python In [5]: pt.full((40,), -10.3, device="mps") Out[5]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int() Out[6]: tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883], device='mps:0', dtype=torch.int32) In [7]: pt.full((40,), -10.3, device="mps").int().float() Out[7]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [8]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[8]: tensor([ True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True], device='mps:0') ``` After: ```python In [3]: pt.full((40,), -10.3, device="mps") Out[3]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [4]: pt.full((40,), -10.3, device="mps").int() Out[4]: tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10], device='mps:0', dtype=torch.int32) In [5]: pt.full((40,), -10.3, device="mps").int().float() Out[5]: tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[6]: tensor([True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True], device='mps:0') ``` Pull Request resolved: #78092 Approved by: https://github.com/kulinseth, https://github.com/malfet Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a52bfe2c5d8588b8f9e83e0beecdd18a1d672d0e Reviewed By: mehtanirav Differential Revision: D36668700 fbshipit-source-id: 380086b940c6a1c83e5a9d1cfb5cd40802153c56
Fixes pytorch#78091 If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~. Before: ```python In [5]: pt.full((40,), -10.3, device="mps") Out[5]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int() Out[6]: tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883], device='mps:0', dtype=torch.int32) In [7]: pt.full((40,), -10.3, device="mps").int().float() Out[7]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [8]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[8]: tensor([ True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True], device='mps:0') ``` After: ```python In [3]: pt.full((40,), -10.3, device="mps") Out[3]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [4]: pt.full((40,), -10.3, device="mps").int() Out[4]: tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10], device='mps:0', dtype=torch.int32) In [5]: pt.full((40,), -10.3, device="mps").int().float() Out[5]: tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[6]: tensor([True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True], device='mps:0') ``` Pull Request resolved: pytorch#78092 Approved by: https://github.com/kulinseth, https://github.com/malfet
Fixes #78091 If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~. Before: ```python In [5]: pt.full((40,), -10.3, device="mps") Out[5]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int() Out[6]: tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883], device='mps:0', dtype=torch.int32) In [7]: pt.full((40,), -10.3, device="mps").int().float() Out[7]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [8]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[8]: tensor([ True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True], device='mps:0') ``` After: ```python In [3]: pt.full((40,), -10.3, device="mps") Out[3]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [4]: pt.full((40,), -10.3, device="mps").int() Out[4]: tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10], device='mps:0', dtype=torch.int32) In [5]: pt.full((40,), -10.3, device="mps").int().float() Out[5]: tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[6]: tensor([True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True], device='mps:0') ``` Pull Request resolved: #78092 Approved by: https://github.com/kulinseth, https://github.com/malfet (cherry picked from commit a52bfe2)
Fixes #78091
If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested
but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used.Before:
After: