Create linalg.tensordot#63478

Closed
antocuni wants to merge 39 commits intopytorch:masterfrom
antocuni:antocuni/linalg-tensordot

Conversation

@antocuni
Contributor

@antocuni antocuni commented Aug 18, 2021

PR stack (ghstack-style but done manually)

Fixes #61649.
As the title says, this introduces torch.linalg.tensordot and makes torch.tensordot an alias of it.
However, tensordot is a bit different from most other linalg operators, so it is worth discussing how to proceed.
Unlike the other operators, the user-facing entry point of tensordot is implemented in Python, in torch/functional.py, which in turn calls torch._C._linalg.linalg_tensordot.

What I did so far in this PR is:

  1. introduce the C++-level at::linalg_tensordot and make it the main implementation
  2. make at::tensordot a C++ alias to at::linalg_tensordot
  3. leave the python implementation inside torch/functional.py
  4. make it available as both torch.tensordot and torch.linalg.tensordot
  5. update all the docs to use torch.linalg.tensordot everywhere
  6. update all the existing code to use torch.linalg.tensordot everywhere
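Steps (2) and (4) can be sketched at the Python level as follows. This is a minimal illustration, not the actual PyTorch wiring: the module objects below are stand-ins for the real torch namespaces, used only to show that both public names bind the very same function object.

```python
import types

# Stand-in for torch/functional.py, where the single Python implementation lives
functional = types.ModuleType("functional")

def tensordot(a, b, dims=2):
    """Compute a tensor contraction (placeholder body for illustration)."""
    return ("tensordot", a, b, dims)

functional.tensordot = tensordot

# Stand-ins for the torch and torch.linalg namespaces
torch_ns = types.ModuleType("torch")
linalg_ns = types.ModuleType("torch.linalg")

# Both namespaces point at the very same function object
torch_ns.tensordot = functional.tensordot
linalg_ns.tensordot = functional.tensordot

print(torch_ns.tensordot is linalg_ns.tensordot)  # True
```

Because the two names are one object, they necessarily share signature, behavior, and docstring, which is exactly the documentation issue discussed below.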

I am not fully sure about point (3), though. One possible alternative is to move the implementation to, e.g., a newly created torch/linalg/functional.py, but this creates more problems because other pieces of code and/or tests expect all these functions to live inside torch/functional.py.

The other problem with the current solution is that since torch.tensordot and torch.linalg.tensordot are the very same Python-level function, they share the same docstring and thus produce the same generated documentation page. Maybe it is possible to tweak Sphinx to use custom text for torch.tensordot instead of the docstring, but I haven't investigated this yet.
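One conceivable workaround for the shared-docstring problem is a thin wrapper that carries its own docstring, so Sphinx would render a short "alias of" page instead of duplicating the full docs. This is only a sketch of the idea, not what the PR does; the wrapper and its docstring text are hypothetical.

```python
import functools

def linalg_tensordot(a, b, dims=2):
    """Full documentation of the linalg-namespaced function (placeholder)."""
    return dims

# functools.wraps copies the metadata, then we override __doc__ so the
# alias gets its own (much shorter) documentation text.
@functools.wraps(linalg_tensordot)
def tensordot(*args, **kwargs):
    return linalg_tensordot(*args, **kwargs)

tensordot.__doc__ = "Alias of :func:`torch.linalg.tensordot`."
```

The cost is that the two names are no longer the identical function object, only behaviorally equivalent ones.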

I opted for the current solution because it was the easiest to implement and I wanted to get some feedback before spending time on more complex ones. Do you think it is enough, or should we try to differentiate more between torch.tensordot and torch.linalg.tensordot?

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi

@facebook-github-bot
Contributor

facebook-github-bot commented Aug 18, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 77a19da (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-xenial-cuda11.3-py3.6-gcc7 / build (1/1)

Step: "Unknown" (full log | diagnosis details | 🔁 rerun)

2021-12-01T16:42:49.1529770Z   echo "ERR...t available for the merge-base of your branch"
2021-12-01T16:42:49.1523532Z fi
2021-12-01T16:42:49.1524010Z # Covers the case where a previous tag doesn't exist for the tree
2021-12-01T16:42:49.1524759Z # this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly
2021-12-01T16:42:49.1525457Z if ! git rev-parse "$MERGE_BASE:.circleci/docker"; then
2021-12-01T16:42:49.1526217Z   echo "Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit"
2021-12-01T16:42:49.1526822Z   exit 1
2021-12-01T16:42:49.1527138Z fi
2021-12-01T16:42:49.1527626Z PREVIOUS_DOCKER_TAG=$(git rev-parse "$MERGE_BASE:.circleci/docker")
2021-12-01T16:42:49.1528410Z # If no image exists but the hash is the same as the previous hash then we should error out here
2021-12-01T16:42:49.1529054Z if [[ "${PREVIOUS_DOCKER_TAG}" = "${DOCKER_TAG}" ]]; then
2021-12-01T16:42:49.1529770Z   echo "ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch"
2021-12-01T16:42:49.1530525Z   echo "       contact the PyTorch team to restore the original images"
2021-12-01T16:42:49.1531006Z   exit 1
2021-12-01T16:42:49.1531306Z fi
2021-12-01T16:42:49.1531711Z echo ::set-output name=rebuild::yes
2021-12-01T16:42:49.1542170Z shell: /usr/bin/bash -e {0}
2021-12-01T16:42:49.1542527Z env:
2021-12-01T16:42:49.1543092Z   BUILD_ENVIRONMENT: linux-xenial-cuda11.3-py3.6-gcc7
2021-12-01T16:42:49.1544334Z   DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7
2021-12-01T16:42:49.1545576Z   SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2
2021-12-01T16:42:49.1546518Z   XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@antocuni antocuni added module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul module: python array api Issues related to the Python Array API labels Aug 18, 2021
@antocuni antocuni changed the title WIP: Create linalg.tensordot Create linalg.tensordot Aug 18, 2021
@antocuni antocuni marked this pull request as ready for review August 18, 2021 15:40
@ejguan ejguan added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 18, 2021
@ezyang ezyang removed their request for review August 18, 2021 23:57
Comment thread aten/src/ATen/native/native_functions.yaml
python_module: linalg
variants: function

- func: linalg_tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!)
Collaborator


I think it's worth investing the time to port the implementation to be a "structured kernel". See https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md#structured-keyword-proposal

Comment thread torch/functional.py
Comment thread torch/linalg/__init__.py Outdated
Comment thread test/test_linalg.py Outdated
Comment thread torch/__init__.py
Comment thread torch/functional.py
from torch import _VF
from torch._C import _linalg # type: ignore[attr-defined]

__all__ = [
Collaborator


Removing tensordot from __all__ would decouple torch.tensordot and torch.functional.tensordot. Then you'd get it as an alias to torch._C._linalg.linalg_tensordot and it should be possible to add a separate docstring in _torch_docs.py that would be simply referring to torch.linalg.tensordot.
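A minimal illustration of the mechanism behind this suggestion: names omitted from `__all__` are skipped by a star import, so the torch-level name could then be rebound separately. The `fake_functional` module below is a made-up stand-in, not the real torch.functional.

```python
import sys
import types

# Build a throwaway module whose __all__ deliberately omits tensordot
mod = types.ModuleType("fake_functional")
exec(
    "__all__ = ['einsum']\n"
    "def einsum(): return 'einsum'\n"
    "def tensordot(): return 'python tensordot'\n",
    mod.__dict__,
)
sys.modules["fake_functional"] = mod

# Star-importing picks up only the names listed in __all__
ns = {}
exec("from fake_functional import *", ns)

print("einsum" in ns, "tensordot" in ns)  # True False
```

With tensordot excluded from the star import, the top-level name would be free to be bound to the C++ entry point with its own docstring.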

Contributor Author


I don't think this is doable: if I make torch.tensordot an alias of torch._C._linalg.linalg_tensordot, it would have a different signature and behavior from torch.linalg.tensordot, because the latter has a Python implementation.

- func: linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg

- func: linalg_tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor
Collaborator


The signature in Python is tensordot(a, b, dims=2, out=None), where dims is an int, a Tuple[List[int], List[int]], a List[List[int]] containing two lists, or a Tensor. Can all the logic on the dims argument be ported from Python to C++? If not, maybe the Python signature should be modified so that the mapping from Python to C++ is more direct and no glue code from functional.py is necessary?
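For reference, the two main forms of the dims argument can be illustrated with numpy.tensordot, whose axes parameter follows the same convention (an assumption worth double-checking against the torch docs):

```python
import numpy as np

a = np.random.rand(2, 3, 4)
b = np.random.rand(3, 4, 5)

# dims as an int: contract the last 2 axes of `a` with the first 2 axes of `b`
r1 = np.tensordot(a, b, axes=2)

# dims as a pair of axis lists: name the contracted axes explicitly
r2 = np.tensordot(a, b, axes=([1, 2], [0, 1]))

print(r1.shape, np.allclose(r1, r2))  # (2, 5) True
```

The integer form is sugar for the list form, which is the kind of normalization the Python glue code currently performs before dispatching.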

Contributor Author


Honestly, I don't know the answer to this question. The fact that the logic was originally written in Python makes me suspect that there was a good reason to do so, but I don't really know if the original reason still holds.
E.g., maybe the dispatcher cannot handle such a complex overload?

Comment thread docs/source/linalg.rst Outdated
Comment thread docs/source/linalg.rst
@pytorch-probot

pytorch-probot Bot commented Sep 30, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/antocuni/pytorch/blob/77a19da83e1b3dc98aa8eec887ac6c6f7fa471dd/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-cuda11.5-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
docker-builds ciflow/all 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
libtorch-linux-bionic-cuda11.5-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

Comment thread torch/__init__.py Outdated
Comment thread torch/fx/operator_schemas.py Outdated
Comment thread torch/jit/_builtins.py Outdated
@mruberry
Collaborator

mruberry commented Oct 1, 2021

Creating a linalg.tensordot alias is unfortunately tricky because of all the special handling we have for functionals.

To add it we'd probably need to create linalg.functional.py and then also bind the function defined there in functional.py?

There are some other issues with tensordot itself that may be interesting to address first:

  • making it structured
  • correctness
  • docs

I filed #65989 for the correctness issue (and a suggestion to fix the docs). PRs addressing these three issues would be extremely interesting if you'd prefer to start with those, @antocuni.

For the docs in particular it'd be helpful to borrow a page from NumPy and show the common cases as well as their equivalent computations. The current torch docs are nearly indecipherable because they use undefined terms and the mathematical formula is (at best) very obtuse.
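The kind of "common case plus equivalent computation" examples the NumPy docs provide might look like this, sketched here with numpy itself rather than torch:

```python
import numpy as np

a = np.random.rand(3, 4)
b = np.random.rand(4, 5)

# dims=1 on matrices is ordinary matrix multiplication
assert np.allclose(np.tensordot(a, b, axes=1), a @ b)

# dims=([1], [0]) says the same thing with explicit axis lists
assert np.allclose(np.tensordot(a, b, axes=([1], [0])), a @ b)

# dims=2 on same-shaped matrices is a full contraction: sum(a * b)
c = np.random.rand(3, 4)
assert np.allclose(np.tensordot(a, c, axes=2), (a * c).sum())
```

Pairing each dims form with a familiar equivalent operation would make the torch docs far easier to decipher than the current mathematical formula.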

@antocuni
Contributor Author

antocuni commented Oct 7, 2021

Creating a linalg.tensordot alias is unfortunately tricky because of all the special handling we have for functionals.

To add it we'd probably need to create linalg.functional.py and then also bind the function defined there in functional.py?

Done. I just pushed, let's see if the tests pass.

There are some other issues with tensordot itself that may be interesting to address first:

  • making it structured

I already started/tried in #64819 but it can't be done at the moment because apparently CompositeImplicitAutograd is not compatible with structured kernels. See the PR for more details.

  • correctness
  • docs

I filed #65989 for the correctness issue (and a suggestion to fix the docs). PRs addressing these three issues would be extremely interesting if you'd prefer to start with those, @antocuni.

I'm confused: according to #65989 the correctness problem is already solved, isn't it? Is there anything else which should be done?

For the docs in particular it'd be helpful to borrow a page from NumPy and show the common cases as well as their equivalent computations. The current torch docs are nearly indecipherable because they use undefined terms and the mathematical formula is (at best) very obtuse.

Better docs always sound like a good idea, and I'd be happy to work on that. But it doesn't sound like a blocker for this PR; can we merge it first?

Comment thread torch/jit/_builtins.py
Comment thread aten/src/ATen/autocast_mode.cpp
Collaborator

@mruberry mruberry left a comment


This looks really good, @antocuni. I made an inline comment about preserving BC. There are a few other places that use tensordot that may need an update:

https://github.com/pytorch/pytorch/blob/master/docs/source/amp.rst

and

("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),

(we should be sure both tensordot and linalg.tensordot are handled properly)

And the last change I'd like to suggest: while linalg.tensordot is the suggested user-facing op, internally we should not change the implementation of tensordot but instead alias linalg.tensordot to it. This might avoid some internal refactoring issues (which we don't expect to apply in the future).

@antocuni
Contributor Author

@mruberry I applied the changes you suggested and moved the implementation of tensordot back to torch.functional

class M(torch.nn.Module):
    def forward(self, x, y):
        output = torch.tensordot(x, y, 2)
        output = torch.linalg.tensordot(x, y, 2)
Collaborator


Should we keep one redundant test for torch.tensordot()?

Collaborator

@mruberry mruberry left a comment


Hey @antocuni! Sorry to make you wait and thanks for your patience. I'm finally in the NYC office.

Overall this looks great; I added just a few inline comments for your review. Looking forward to hearing your thoughts!

Comment thread test/test_linalg.py
b = torch.randn(4, 5, 6, 7, device=device)
c = torch.tensordot(a, b, dims=2).cpu()
# case 4: dims is an integer
a = torch.randn(2, 3, 4, 5, dtype=dtype, device=device)
Collaborator


Nice dtype updates

Comment thread test/test_linalg.py Outdated
self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0, 0), torch.randn(0, 0)))
self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0,), torch.randn(0, 0)))

@dtypes(*floating_types())
Collaborator


We can probably just keep running this in tf32, unless there's a particular reason this test needs to cover double, too?

Contributor Author


Yes, there is no particular reason. I switched to testing only torch.float32 in 33345f7.

Comment thread test/test_linalg.py Outdated

a = torch.tensordot(torch.tensor(0.), torch.tensor(0.), 0)
# case 8: dims=0
a = torch.linalg.tensordot(torch.tensor(0.), torch.tensor(0.), 0)
Collaborator


These tensors need to use the device and dtype, too

Contributor Author


right, added in d62803a

Comment thread torch/functional.py

if out is None:
return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore[attr-defined]
return torch._C._linalg.linalg_tensordot(a, b, dims_a, dims_b) # type: ignore[attr-defined]
Collaborator


I'm a little confused by this part still. This functional has a bunch of logic before calling torch._C._linalg.linalg_tensordot(), but when torch.linalg.tensordot() is called how does it implement all the above logic?

Contributor Author

@antocuni antocuni Dec 1, 2021


torch.linalg.tensordot is the very same python function as torch.functional.tensordot:

>>> import torch
>>> torch.linalg.tensordot is torch.functional.tensordot
True

@antocuni
Contributor Author

antocuni commented Dec 1, 2021

@mruberry hopefully I addressed your latest comments. I also updated to the latest master and fixed conflicts (tests are still running at the time of writing; hopefully they will be green).

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions Bot added the Stale label May 21, 2022
@github-actions github-actions Bot closed this Jun 20, 2022

Labels

cla signed module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul module: python array api Issues related to the Python Array API open source Stale triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Move tensordot into torch.linalg

8 participants