Convert MPS Tensor data using MPSGraph API #78092

Closed
lhoenig wants to merge 11 commits into pytorch:master from lhoenig:lhoenig-mps-dtype-conversion

Conversation

@lhoenig
Contributor

@lhoenig lhoenig commented May 23, 2022

Fixes #78091
If you are already working on this, simply disregard this or take whatever may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested, but is currently only implemented for the MPS-to-MPS copy, not MPS-to-X or X-to-MPS; the same approach could easily be used there too.

Before:

In [5]: pt.full((40,), -10.3, device="mps")
Out[5]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int()
Out[6]:
tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883],
       device='mps:0', dtype=torch.int32)

In [7]: pt.full((40,), -10.3, device="mps").int().float()
Out[7]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [8]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[8]:
tensor([ True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True,
         True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True],
       device='mps:0')
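
The bogus values in `Out[6]` are not random: -1054552883 is exactly the bit pattern of the float32 value -10.3 reinterpreted as an int32, which is what a raw buffer copy without a cast produces. A quick `struct` round-trip in plain Python (no torch required) confirms the number:

```python
import struct

# Pack -10.3 as a little-endian 32-bit float, then reinterpret
# the same 4 bytes as a signed 32-bit integer.
raw = struct.pack("<f", -10.3)
reinterpreted = struct.unpack("<i", raw)[0]
print(reinterpreted)  # -1054552883, the value in the broken output above
```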

After:

In [3]: pt.full((40,), -10.3, device="mps")
Out[3]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [4]: pt.full((40,), -10.3, device="mps").int()
Out[4]:
tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10],
       device='mps:0', dtype=torch.int32)

In [5]: pt.full((40,), -10.3, device="mps").int().float()
Out[5]:
tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10.], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[6]:
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True], device='mps:0')

When doing an MPS-to-MPS copy where source and destination have different dtypes, build and run an MPSGraph with one call to [MPSGraph castTensor:toType:name:] to actually convert the tensor data.
@facebook-github-bot
Contributor

facebook-github-bot commented May 23, 2022


❌ 5 New Failures

As of commit d99deb1 (more details on the Dr. CI page):

  • 5/5 failures introduced in this PR

🕵️ 5 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (1/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-24T09:04:30.2053682Z processing existing schema:  text(__torch__.torch.classes.profiling.SourceRef _0) -> (str _0)
2022-05-24T09:04:30.2054715Z processing existing schema:  count(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-05-24T09:04:30.2056062Z processing existing schema:  duration_ns(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-05-24T09:04:30.2057363Z processing existing schema:  source(__torch__.torch.classes.profiling.SourceStats _0) -> (__torch__.torch.classes.profiling.SourceRef _0)
2022-05-24T09:04:30.2059081Z processing existing schema:  line_map(__torch__.torch.classes.profiling.SourceStats _0) -> (Dict(int, __torch__.torch.classes.profiling.InstructionStats) _0)
2022-05-24T09:04:30.2059965Z processing existing schema:  __init__(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-24T09:04:30.2061124Z processing existing schema:  enable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-24T09:04:30.2062422Z processing existing schema:  disable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-24T09:04:30.2064030Z processing existing schema:  _dump_stats(__torch__.torch.classes.profiling._ScriptProfile _0) -> (__torch__.torch.classes.profiling.SourceStats[] _0)
2022-05-24T09:04:30.2065647Z processing existing schema:  __init__(__torch__.torch.classes.dist_rpc.WorkerInfo _0, str _1, int _2) -> (NoneType _0)
2022-05-24T09:04:30.2065906Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2022-05-24T09:04:30.2066022Z 
2022-05-24T09:04:30.2066096Z Broken ops: [
2022-05-24T09:04:30.2066270Z 	prims::zeta(Tensor self, Tensor other) -> (Tensor)
2022-05-24T09:04:30.2066415Z 	prims::bessel_i1(Tensor self) -> (Tensor)
2022-05-24T09:04:30.2066543Z 	prims::bessel_i0(Tensor self) -> (Tensor)
2022-05-24T09:04:30.2066603Z ]
2022-05-24T09:04:30.3064393Z + cleanup
2022-05-24T09:04:30.3064474Z + retcode=1
2022-05-24T09:04:30.3064537Z + set +x
2022-05-24T09:04:30.3117040Z ##[error]Process completed with exit code 1.

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge) (2/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-24T10:16:21.6861804Z cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
2022-05-24T10:16:21.6870779Z g++ -pthread -B /opt/conda/compiler_compat -Wl,--sysroot=/ -pthread -shared -B /opt/conda/compiler_compat -L/opt/conda/lib -Wl,-rpath=/opt/conda/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.7/extension.o -L/opt/conda/lib/python3.7/site-packages/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-3.7/torch_test_cpp_extension/cpp.cpython-37m-x86_64-linux-gnu.so
2022-05-24T10:16:21.6899912Z OMP: Error #15: Initializing libiomp5.so, but found unknown library already initialized.
2022-05-24T10:16:21.6900867Z OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.
2022-05-24T10:16:21.7688495Z error: command '/opt/cache/bin/g++' failed with exit code -6
2022-05-24T10:16:21.9587459Z Traceback (most recent call last):
2022-05-24T10:16:21.9587734Z   File "test/run_test.py", line 1074, in <module>
2022-05-24T10:16:21.9589720Z     main()
2022-05-24T10:16:21.9590015Z   File "test/run_test.py", line 1052, in main
2022-05-24T10:16:21.9591739Z     raise RuntimeError(err_message)
2022-05-24T10:16:21.9592047Z RuntimeError: test_cpp_extensions_aot_no_ninja failed!
2022-05-24T10:16:22.2052187Z 
2022-05-24T10:16:22.2052514Z real	69m56.629s
2022-05-24T10:16:22.2052900Z user	182m19.630s
2022-05-24T10:16:22.2053190Z sys	12m47.997s
2022-05-24T10:16:22.2053459Z + cleanup
2022-05-24T10:16:22.2053619Z + retcode=1
2022-05-24T10:16:22.2053765Z + set +x
2022-05-24T10:16:22.2086215Z ##[error]Process completed with exit code 1.
2022-05-24T10:16:22.2206147Z ##[group]Run pytorch/pytorch/.github/actions/get-workflow-job-id@master
2022-05-24T10:16:22.2206399Z with:

See GitHub Actions build linux-binary-manywheel / manywheel-py3_7-rocm5_1_1-test (3/5)

Step: "Download Build Artifacts" (full log | diagnosis details | 🔁 rerun)

2022-05-24T10:12:47.7641096Z   PACKAGE_TYPE: manywheel
2022-05-24T10:12:47.7641352Z   DESIRED_CUDA: rocm5.1.1
2022-05-24T10:12:47.7641608Z   GPU_ARCH_VERSION: 5.1.1
2022-05-24T10:12:47.7641849Z   GPU_ARCH_TYPE: rocm
2022-05-24T10:12:47.7642151Z   DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1
2022-05-24T10:12:47.7642454Z   DESIRED_PYTHON: 3.7
2022-05-24T10:12:47.7642732Z   DOCKER_HOST: unix:///run/user/1502/docker.sock
2022-05-24T10:12:47.7643009Z ##[endgroup]
2022-05-24T10:12:48.5200959Z Found 1 objects with prefix pytorch/pytorch/2376623312/1/manywheel-py3_7-rocm5_1_1/
2022-05-24T10:12:48.5203761Z Starting download (1/1): /home/pytorchci/actions-runner/_work/_temp/artifacts/torch-1.13.0.dev20220524+rocm5.1.1-cp37-cp37m-linux_x86_64.whl
2022-05-24T10:14:53.6570352Z ##[error]Stream content length mismatch. Received 626477658 of 1354497565 bytes.
2022-05-24T10:14:53.6758088Z ##[group]Run # ignore expansion of "docker ps -q" since it could be empty
2022-05-24T10:14:53.6758906Z # ignore expansion of "docker ps -q" since it could be empty
2022-05-24T10:14:53.6759544Z # shellcheck disable=SC2046
2022-05-24T10:14:53.6760120Z docker stop $(docker ps -q) || true
2022-05-24T10:14:53.6760692Z # Prune all of the docker images
2022-05-24T10:14:53.6761228Z docker system prune -af
2022-05-24T10:14:53.6798270Z shell: /bin/bash -e {0}
2022-05-24T10:14:53.6798702Z env:
2022-05-24T10:14:53.6799257Z   ALPINE_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine
2022-05-24T10:14:53.6799874Z   ANACONDA_USER: pytorch

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu) (4/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-24T09:30:17.4888198Z INFO:torch.nn.parallel.distributed:Reducer buckets have been rebuilt in this iteration.
2022-05-24T09:30:17.4890350Z INFO:torch.nn.parallel.distributed:Reducer buckets have been rebuilt in this iteration.
2022-05-24T09:30:17.5020417Z INFO:torch.nn.parallel.distributed:Reducer buckets have been rebuilt in this iteration.
2022-05-24T09:30:17.5023157Z INFO:torch.nn.parallel.distributed:Reducer buckets have been rebuilt in this iteration.
2022-05-24T09:30:17.5029151Z INFO:torch.nn.parallel.distributed:Reducer buckets have been rebuilt in this iteration.
2022-05-24T09:30:17.5032513Z INFO:torch.nn.parallel.distributed:Reducer buckets have been rebuilt in this iteration.
2022-05-24T09:30:17.8597498Z ok (3.538s)
2022-05-24T09:30:17.8598197Z     test_ddp_buffer_hook_allreduce succeeded - num_retries_left: 0
2022-05-24T09:30:17.8598480Z 
2022-05-24T09:30:17.8598654Z ======================================================================
2022-05-24T09:30:17.8599054Z FAIL [6.641s]: test_ddp_buffer_hook_allreduce (__main__.TestDistBackendWithSpawn)
2022-05-24T09:30:17.8599571Z ----------------------------------------------------------------------
2022-05-24T09:30:17.8599913Z Traceback (most recent call last):
2022-05-24T09:30:17.8600472Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 501, in wrapper
2022-05-24T09:30:17.8600858Z     self._join_processes(fn)
2022-05-24T09:30:17.8601445Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 724, in _join_processes
2022-05-24T09:30:17.8601883Z     self._check_return_codes(elapsed_time)
2022-05-24T09:30:17.8602444Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py", line 783, in _check_return_codes
2022-05-24T09:30:17.8602854Z     i, first_process.exitcode, p.exitcode
2022-05-24T09:30:17.8603438Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 2255, in assertEqual
2022-05-24T09:30:17.8603959Z     msg=(lambda generated_msg: f"{generated_msg} : {msg}") if isinstance(msg, str) and self.longMessage else msg,

See GitHub Actions build trunk / macos-11-py3-x86-64 / test (default, 1, 2, macos-12) (5/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-24T11:12:46.2455050Z ERROR [0.016s]: test_smooth_l1_loss_reduction_sum (__main__.TestSmoothL1Loss)
2022-05-24T11:12:46.2455310Z ----------------------------------------------------------------------
2022-05-24T11:12:46.2455410Z Traceback (most recent call last):
2022-05-24T11:12:46.2455560Z   File "test_mps.py", line 1334, in test_smooth_l1_loss_reduction_sum
2022-05-24T11:12:46.2455680Z     self._smooth_l1_loss_helper(reduction="sum")
2022-05-24T11:12:46.2455800Z   File "test_mps.py", line 1312, in _smooth_l1_loss_helper
2022-05-24T11:12:46.2456050Z     input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
2022-05-24T11:12:46.2456170Z RuntimeError: Invalid buffer size: 112 bytes
2022-05-24T11:12:46.2456180Z 
2022-05-24T11:12:46.2456280Z ======================================================================
2022-05-24T11:12:46.2456460Z FAIL [1.489s]: test_warn_on_not_implemented_with_fallback (__main__.TestFallbackWarning)
2022-05-24T11:12:46.2456700Z ----------------------------------------------------------------------
2022-05-24T11:12:46.2456810Z Traceback (most recent call last):
2022-05-24T11:12:46.2456970Z   File "test_mps.py", line 4111, in test_warn_on_not_implemented_with_fallback
2022-05-24T11:12:46.2457070Z     subprocess.check_output(
2022-05-24T11:12:46.2458320Z subprocess.CalledProcessError: Command '['/Users/runner/miniconda3/envs/build/bin/python', '-W', 'all', '-c', '\nimport os\n# MUST happen before pytorch\'s import\nos.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"\nimport warnings\n\nwith warnings.catch_warnings(record=True) as w:\n    import torch\n\nif len(w) > 0:\n    exit(1)\n\n# This should run just fine and raise warning about perf\nwith warnings.catch_warnings(record=True) as w:\n    torch.eye(2, device=\'mps\')\n\nif len(w) != 1:\n    exit(2)\n\n']' returned non-zero exit status 1.
2022-05-24T11:12:46.2458530Z 
2022-05-24T11:12:46.2458720Z During handling of the above exception, another exception occurred:
2022-05-24T11:12:46.2458730Z 
2022-05-24T11:12:46.2458830Z Traceback (most recent call last):
2022-05-24T11:12:46.2458990Z   File "test_mps.py", line 4119, in test_warn_on_not_implemented_with_fallback

This comment was automatically generated by Dr. CI.



@lhoenig
Contributor Author

lhoenig commented May 23, 2022

Here is a small program I used to test the approach: https://gist.github.com/lhoenig/bb0636548b2abbe690e81a3fad139ee3

@albanD albanD requested a review from kulinseth May 23, 2022 18:50
@albanD
Collaborator

albanD commented May 23, 2022

@kulinseth does this work better than your WIP solution?

Also if we go with this, we should add tests.

@albanD albanD added ciflow/trunk Trigger trunk jobs on your pull request ciflow/binaries_wheel Trigger binary build and upload jobs for wheel on the PR labels May 23, 2022
@lhoenig
Contributor Author

lhoenig commented May 23, 2022

Just added commits that also fix dtype conversion in copy_from_mps_ and copy_to_mps_. The idea for copy_to_mps_ is simply to convert the tensor on the source device before copying; for copy_from_mps_, the MTLBuffer gets converted prior to copying using the same approach as copy_kernel_mps, just in place, i.e. writing back to the same MPSGraphTensorData. I'm not completely sure whether that is legal (I can't find any documentation about it), but it works nicely.

This way the offsets in the blit copy calls can still be used in both transfer copy functions.

Before:

In [4]: pt.tensor(1.3, device="mps").to("cpu", pt.int)
Out[4]: tensor(1067869824, dtype=torch.int32)

In [5]: pt.tensor(1.3, device="cpu").to("mps", pt.int)
Out[5]: tensor(1067869798, device='mps:0', dtype=torch.int32)

After:

In [4]: pt.tensor(1.3, device="mps").to("cpu", pt.int)
Out[4]: tensor(1, dtype=torch.int32)

In [5]: pt.tensor(1.3, device="cpu").to("mps", pt.int)
Out[5]: tensor(1, device='mps:0', dtype=torch.int32)
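
The Before numbers here are the same bit-reinterpretation at work: 1067869798 is the float32 bit pattern of 1.3 read back as an int32 (the MPS-side 1067869824 differs only in the low mantissa bits), whereas a true cast yields 1. NumPy's `view` versus `astype` mirrors the two behaviors; this is an illustrative analogy only, not the code in this PR:

```python
import numpy as np

x = np.array([1.3], dtype=np.float32)

# view() reinterprets the raw bytes: the buggy pre-fix behavior
print(int(x.view(np.int32)[0]))    # 1067869798

# astype() performs a value cast: the post-fix behavior
print(int(x.astype(np.int32)[0]))  # 1
```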

@kulinseth
Collaborator

> Just added commits that also fix dtype conversion in copy_from_mps_ and copy_to_mps_. The idea for copy_to_mps_ is simply to convert the tensor on the source device before copying, and for copy_from_mps_ the MTLBuffer gets converted prior to copying using the same approach as copy_kernel_mps, just in place, i.e. writing back to the same MPSGraphTensorData. I'm not completely sure if that is legal, can't find any documentation about it, but it works nicely.
>
> This way the offsets in the blit copy calls can still be used in both transfer copy functions.
>
> Before:
>
> In [4]: pt.tensor(1.3, device="mps").to("cpu", pt.int)
> Out[4]: tensor(1067869824, dtype=torch.int32)
>
> In [5]: pt.tensor(1.3, device="cpu").to("mps", pt.int)
> Out[5]: tensor(1067869798, device='mps:0', dtype=torch.int32)
>
> After:
>
> In [4]: pt.tensor(1.3, device="mps").to("cpu", pt.int)
> Out[4]: tensor(1, dtype=torch.int32)
>
> In [5]: pt.tensor(1.3, device="cpu").to("mps", pt.int)
> Out[5]: tensor(1, device='mps:0', dtype=torch.int32)

Thanks @lhoenig , the fix makes sense. I am testing that, right now. The casting is indeed required if the types are different when copying over. I will have some comments after i finish testing. Will post here.

@kulinseth
Collaborator

> @kulinseth does this work better than your WIP solution?
>
> Also if we go with this, we should add tests.

Thanks @albanD . Totally agree, please add tests in test_mps.py file. The ones you have in the description will be good.

@kulinseth
Collaborator

> @kulinseth does this work better than your WIP solution?
>
> Also if we go with this, we should add tests.

This is a good change, let's go with this. @lhoenig, please add the requested changes and we should be good.

@lhoenig
Contributor Author

lhoenig commented May 23, 2022

> > @kulinseth does this work better than your WIP solution?
> > Also if we go with this, we should add tests.
>
> This is a good change, let's go with this. @lhoenig, please add the requested changes and we should be good.

Awesome, thanks for the fast review, will get to it quickly!

@lhoenig
Contributor Author

lhoenig commented May 23, 2022

@kulinseth Ok, done for now, have also added a test in test_mps.py. Let me know if anything more comes up!

kulinseth
kulinseth previously approved these changes May 24, 2022
Collaborator

@kulinseth kulinseth left a comment


Please fix the lintrunner/flake8 issues, but otherwise it looks good.

@lhoenig, I am seeing regressions in our tests.

@kulinseth kulinseth self-requested a review May 24, 2022 00:13
@kulinseth kulinseth dismissed their stale review May 24, 2022 00:31

I am seeing regressions in test_mps.

@kulinseth
Collaborator

> Please fix the lintrunner/flake8 issues, but otherwise it looks good.
>
> @lhoenig, I am seeing regressions in our tests.

It's failing with:

ERROR: test_baddbmm (__main__.TestMPS)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_mps.py", line 327, in test_baddbmm
    helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
  File "test/test_mps.py", line 321, in helper
    output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)
TypeError: Trying to convert UNKNOWN_SCALAR to the MPS backend but it does not have support for that dtype.

@kulinseth
Collaborator

@lhoenig, I have some fixes, can I upload them as a commit to this PR?

@lhoenig
Contributor Author

lhoenig commented May 24, 2022

> @lhoenig, I have some fixes, can I upload them as a commit to this PR?

Sure!

@lhoenig
Contributor Author

lhoenig commented May 24, 2022

@kulinseth I too uploaded fixes now - hope that's ok!
The test failures were due to src being uninitialized in the code paths in copy_from_mps_ and copy_to_mps_ where !src_.is_contiguous() and the gatherTensor is used.

In this code path, src (the one without the underscore) is never initialized, so its scalar_type() was UNKNOWN_SCALAR. Since the src and dst arguments of copy_cast_mps are only used for their shapes and dtypes anyway, I now pass src_ and dst_ to it instead of src and dst. The shapes and dtypes should not be changed by the preprocessing (making the tensors contiguous).
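
The claim that the preprocessing leaves shapes and dtypes untouched is easy to sanity-check. Here is a NumPy stand-in (illustrative only; `ascontiguousarray` plays the role of the MPS gather):

```python
import numpy as np

# A non-contiguous view, analogous to the !src_.is_contiguous() path
x = np.arange(12, dtype=np.float32).reshape(3, 4)[:, ::2]
assert not x.flags["C_CONTIGUOUS"]

# Gathering into a contiguous buffer changes neither shape nor dtype
y = np.ascontiguousarray(x)
assert y.flags["C_CONTIGUOUS"]
assert y.shape == x.shape and y.dtype == x.dtype
```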

@kulinseth
Collaborator

> @kulinseth I too uploaded fixes now - hope that's ok! The test failures were due to src being uninitialized in the code paths in copy_from_mps_ and copy_to_mps_ where !src_.is_contiguous() and the gatherTensor is used.
>
> In this code path, src (the one without the underscore) is never initialized, so its scalar_type() was UNKNOWN_SCALAR. Since the src and dst arguments of copy_cast_mps are only used for their shapes and dtypes anyway, I now pass src_ and dst_ to it instead of src and dst. The shapes and dtypes should not be changed by the preprocessing (making the tensors contiguous).

Awesome, these were the fixes I had locally. Now all the tests are passing.

Collaborator

@kulinseth kulinseth left a comment


Looks great @lhoenig , thanks for the fix.

@danieldk
Contributor

FWIW: I also tested this PR and now our transformer models work (with one additional fix). Thanks a bunch!

@kulinseth
Collaborator

The issues in the failure are not related to this PR:

OMP: Error #15: Initializing libiomp5.so, but found unknown library already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see [http://www.intel.com/software/products/support/.](http://www.intel.com/software/products/support/)
error: command '/opt/cache/bin/g++' failed with exit code -6

@kulinseth
Collaborator

@pytorchbot merge this please.

@pytorchmergebot
Collaborator

Merge failed due to Matched rule superuser, but it was not reviewed yet by any of: lazysjb, HDCharles, bzinodev, huangyi1979, protonu, ...
Raised by https://github.com/pytorch/pytorch/actions/runs/2379945480

@malfet malfet added this to the 1.12.0 milestone May 24, 2022
@malfet
Contributor

malfet commented May 24, 2022

@pytorchbot merge this

@github-actions
Contributor

Hey @lhoenig.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@lhoenig lhoenig deleted the lhoenig-mps-dtype-conversion branch May 24, 2022 20:37
facebook-github-bot pushed a commit that referenced this pull request May 26, 2022
Summary:
Fixes #78091
If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~.

Before:
```python
In [5]: pt.full((40,), -10.3, device="mps")
Out[5]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int()
Out[6]:
tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883],
       device='mps:0', dtype=torch.int32)

In [7]: pt.full((40,), -10.3, device="mps").int().float()
Out[7]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [8]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[8]:
tensor([ True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True,
         True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True],
       device='mps:0')
```

After:
```python
In [3]: pt.full((40,), -10.3, device="mps")
Out[3]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [4]: pt.full((40,), -10.3, device="mps").int()
Out[4]:
tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10],
       device='mps:0', dtype=torch.int32)

In [5]: pt.full((40,), -10.3, device="mps").int().float()
Out[5]:
tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10.], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[6]:
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True], device='mps:0')
```

Pull Request resolved: #78092
Approved by: https://github.com/kulinseth, https://github.com/malfet

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a52bfe2c5d8588b8f9e83e0beecdd18a1d672d0e

Reviewed By: mehtanirav

Differential Revision: D36668700

fbshipit-source-id: 380086b940c6a1c83e5a9d1cfb5cd40802153c56
atalman pushed a commit to atalman/pytorch that referenced this pull request Jun 6, 2022
malfet pushed a commit that referenced this pull request Jun 7, 2022

(cherry picked from commit a52bfe2)

Labels

ciflow/binaries_wheel (trigger binary build and upload jobs for wheel on the PR), ciflow/trunk (trigger trunk jobs on your pull request), cla signed, Merged, open source, topic: improvements


Development

Successfully merging this pull request may close these issues.

MPS: No conversion of Tensor datatype possible

8 participants