
pinv: forward/backward AD which is Frechet-defined in a rank-preserving neighborhood.#66092

Closed
nikitaved wants to merge 12 commits into master from nikitaved/pinv_backward

Conversation

@nikitaved
Collaborator

@nikitaved nikitaved commented Oct 4, 2021

Fixes #65911. Also enables complex support/tests for linalg_pinv in OpInfo.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @walterddr @IvanYashchuk @xwang233
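For context, the derivative being implemented can be sketched numerically. The snippet below is an illustrative NumPy check (not the PR's actual PyTorch code) of the classic differential of the Moore-Penrose pseudoinverse due to Golub and Pereyra, which is well defined exactly in a rank-preserving neighborhood of `A`:

```python
import numpy as np

# Illustrative sketch, not the PR's implementation. For P = pinv(A) and a
# perturbation E that keeps the rank constant, the Golub-Pereyra formula is:
#   dP = -P E P + P P^T E^T (I - A P) + (I - P A) E^T P^T P
rng = np.random.default_rng(0)
m, n = 3, 5
A = rng.standard_normal((m, n))   # full rank almost surely
E = rng.standard_normal((m, n))   # perturbation direction
P = np.linalg.pinv(A)

analytic = (-P @ E @ P
            + P @ P.T @ E.T @ (np.eye(m) - A @ P)
            + (np.eye(n) - P @ A) @ E.T @ P.T @ P)

# Central finite difference of pinv along E
h = 1e-6
numeric = (np.linalg.pinv(A + h * E) - np.linalg.pinv(A - h * E)) / (2 * h)
print(np.allclose(analytic, numeric, atol=1e-5, rtol=1e-4))  # prints True
```

If `A` is rank-deficient and `E` changes its rank, the agreement breaks down, which is why the PR title restricts the claim to a rank-preserving neighborhood.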

@nikitaved nikitaved added module: autograd Related to torch.autograd, and the autograd engine in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul complex_autograd ci/slow-gradcheck labels Oct 4, 2021
@pytorch-probot

pytorch-probot Bot commented Oct 4, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/2b5431eed4d51d4b7566af98116cbc3c649a0978/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
puretorch-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

Comment on lines +8950 to +8951
# Only large tensors show issues with implicit backward used prior to
# explicit backward implementation.
Collaborator Author

@nikitaved nikitaved Oct 4, 2021


A note: these are large tensors of low rank. In my environment I had to create a rank-1 30x30 matrix to see issues with the repeated zero singular values in the backward of SVD.

@ezyang ezyang removed their request for review October 4, 2021 23:16
@ezyang
Contributor

ezyang commented Oct 4, 2021

not sure appropriate FB reviewer has been tagged yet

@facebook-github-bot
Contributor

facebook-github-bot commented Oct 5, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 2b5431e (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Oct 08 11:26:22 RuntimeError: tensorflow/compil...'rendezvous_test.0': Connection reset by peer (14)
Oct 08 11:26:22 Exception in device=CPU:1: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'rendezvous_test.0': Connection reset by peer (14)
Oct 08 11:26:22 Traceback (most recent call last):
Oct 08 11:26:22   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
Oct 08 11:26:22     _start_fn(index, pf_cfg, fn, args)
Oct 08 11:26:22   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
Oct 08 11:26:22     fn(gindex, *args)
Oct 08 11:26:22   File "/var/lib/jenkins/workspace/xla/test/test_mp_rendezvous.py", line 22, in _mp_fn
Oct 08 11:26:22     replicas=replicas)
Oct 08 11:26:22   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/core/xla_model.py", line 875, in rendezvous
Oct 08 11:26:22     return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
Oct 08 11:26:22 RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'rendezvous_test.0': Connection reset by peer (14)
Oct 08 11:26:23 Traceback (most recent call last):
Oct 08 11:26:23   File "/var/lib/jenkins/workspace/xla/test/test_mp_rendezvous.py", line 35, in <module>
Oct 08 11:26:23     xmp.spawn(_mp_fn, args=())
Oct 08 11:26:23   File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.10-py3.6-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 394, in spawn
Oct 08 11:26:23     start_method=start_method)
Oct 08 11:26:23   File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
Oct 08 11:26:23     while not context.join():
Oct 08 11:26:23   File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 144, in join
Oct 08 11:26:23     exit_code=exitcode
Oct 08 11:26:23 torch.multiprocessing.spawn.ProcessExitedException: process 3 terminated with exit code 17

XLA failure

Job pytorch_xla_linux_bionic_py3_6_clang9_test is failing. Please create an issue with title prefixed by [PT_BREAK] in pytorch/xla and link to this PR. If you have questions, please reach out to @ailzhang / @dlibenzi / @JackCaoG.


This comment was automatically generated by Dr. CI.

@nikitaved nikitaved changed the title pinv: forward/backward AD which is Frechet-differentiable in a rank-preserving neighborhood. pinv: forward/backward AD which is Frechet-defined in a rank-preserving neighborhood. Oct 5, 2021
Collaborator

@lezcano lezcano left a comment


Nice! Thanks both for finding the slick formula for this backward and the rather compact implementation!

Comment on lines +2120 to +2121
# Note that by making the columns of `a` and `b` orthonormal we make sure
# that the product matrix `a @ b.t()` has condition number 1.
Collaborator


Nice! This saves us a lot of pain in future debugging.

Now, this note is slightly incorrect. The resulting matrix will have singular values 0 and 1, so the condition number will be infinite! Perhaps you mean that it has condition number 1 when restricted to its image?

Collaborator Author

@nikitaved nikitaved Oct 5, 2021


Yes, exactly: condition number 1 when restricted to the image, so that pinv is stable.
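The construction discussed in this thread can be sketched as follows (illustrative NumPy, not the test's actual code): if `a` and `b` have orthonormal columns, every nonzero singular value of `a @ b.T` is exactly 1, so the matrix is perfectly conditioned on its image even though its full condition number is infinite.

```python
import numpy as np

# Illustrative sketch: build a rank-k matrix from orthonormal factors and
# check that its singular values are exactly 0 or 1.
rng = np.random.default_rng(0)
n, k = 30, 1                                       # e.g. a rank-1 30x30 matrix
a, _ = np.linalg.qr(rng.standard_normal((n, k)))   # orthonormal columns
b, _ = np.linalg.qr(rng.standard_normal((n, k)))
M = a @ b.T

s = np.linalg.svd(M, compute_uv=False)             # sorted descending
print(np.allclose(s[:k], 1.0), np.allclose(s[k:], 0.0))  # prints True True
```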

Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
sample_inputs_func=sample_inputs_linalg_pinv_singular,
# Only large tensors show issues with implicit backward used prior to
# explicit backward implementation.
decorators=[slowTest, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
Collaborator


Is the slowTest decorator working as expected here?

Collaborator Author


Yes!

Collaborator


@albanD It will apply the slowTest decorator to EVERY test generated by this OpInfo
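A hedged sketch of that mechanism (all names below are hypothetical stand-ins, not PyTorch's actual OpInfo machinery): a decorator list attached to an entry gets applied to every test function generated from it, which is why `slowTest` in `decorators` marks all of this OpInfo's tests as slow.

```python
# Hypothetical sketch of how a decorator list is applied to every
# generated test; slow_test stands in for PyTorch's slowTest marker.
def slow_test(fn):
    fn.slow = True
    return fn

def generate_tests(op_name, decorators, variants=("grad", "gradgrad")):
    tests = {}
    for variant in variants:
        def test(_v=variant):
            return f"ran {op_name} {_v}"
        for dec in decorators:   # every generated test gets every decorator
            test = dec(test)
        tests[f"test_{op_name}_{variant}"] = test
    return tests

tests = generate_tests("pinv", [slow_test])
print(all(t.slow for t in tests.values()))  # prints True
```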

@mruberry
Collaborator

mruberry commented Oct 6, 2021

Cool! Do you have before/after perf numbers for the autograd, @nikitaved?

@nikitaved
Collaborator Author

@mruberry, I did run some benchmarks, and surprisingly this PR also improves performance.

This PR, cpu float32:

shape: (10, 10), device: cpu, dtype: torch.float32
29.8 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 10, 10), device: cpu, dtype: torch.float32
797 µs ± 7.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (100, 100), device: cpu, dtype: torch.float32
273 µs ± 3.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 100, 100), device: cpu, dtype: torch.float32
56.4 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

shape: (1000, 1000), device: cpu, dtype: torch.float32
11.7 ms ± 96.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (10, 1000, 1000), device: cpu, dtype: torch.float32
159 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Master, cpu float32:

shape: (10, 10), device: cpu, dtype: torch.float32
86.3 µs ± 3.55 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 10, 10), device: cpu, dtype: torch.float32
2.23 ms ± 398 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (100, 100), device: cpu, dtype: torch.float32
535 µs ± 33.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 100, 100), device: cpu, dtype: torch.float32
174 ms ± 6.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

shape: (1000, 1000), device: cpu, dtype: torch.float32
26.9 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

shape: (10, 1000, 1000), device: cpu, dtype: torch.float32
392 ms ± 30.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This PR, cuda float32:

shape: (10, 10), device: cuda, dtype: torch.float32
111 µs ± 3.58 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 10, 10), device: cuda, dtype: torch.float32
332 µs ± 998 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (100, 100), device: cuda, dtype: torch.float32
111 µs ± 772 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

shape: (1000, 100, 100), device: cuda, dtype: torch.float32
7.25 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (1000, 1000), device: cuda, dtype: torch.float32
3.21 ms ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (10, 1000, 1000), device: cuda, dtype: torch.float32
29.8 ms ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Master, cuda float32:

shape: (10, 10), device: cuda, dtype: torch.float32
282 µs ± 15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 10, 10), device: cuda, dtype: torch.float32
565 µs ± 3.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (100, 100), device: cuda, dtype: torch.float32
312 µs ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

shape: (1000, 100, 100), device: cuda, dtype: torch.float32
11.8 ms ± 41.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (1000, 1000), device: cuda, dtype: torch.float32
4.72 ms ± 32.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

shape: (10, 1000, 1000), device: cuda, dtype: torch.float32
42.4 ms ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
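For reference, per-shape timings like the above can be gathered with a simple loop; this hypothetical NumPy harness only times a forward `np.linalg.pinv` call (the PR's benchmark exercised `torch.linalg.pinv` and its backward), so it illustrates the harness shape, not the numbers:

```python
import timeit
import numpy as np

# Hypothetical harness, not the PR's benchmark script: times np.linalg.pinv
# per shape. The actual measurements above were taken with torch on cpu/cuda.
rng = np.random.default_rng(0)
for shape in [(10, 10), (100, 100)]:
    A = rng.standard_normal(shape)
    mean_s = timeit.timeit(lambda: np.linalg.pinv(A), number=100) / 100
    print(f"shape: {shape}, device: cpu, mean per call: {mean_s * 1e6:.1f} us")
```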

@lezcano
Collaborator

lezcano commented Oct 8, 2021

Faster and correct! There's no better combination than that :)

Collaborator

@albanD albanD left a comment


Looks good. Can you fix the last lint (EDIT: Oh, it looks like the job itself failed...) and I'll merge this.

@facebook-github-bot
Contributor

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@github-actions github-actions Bot deleted the nikitaved/pinv_backward branch February 13, 2024 01:58
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
…ng neighborhood. (pytorch#66092)

Summary:
Fixes pytorch#65911. Also enables complex support/tests for `linalg_pinv` in OpInfo.

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry walterddr IvanYashchuk xwang233

Pull Request resolved: pytorch#66092

Reviewed By: ejguan

Differential Revision: D31503072

Pulled By: albanD

fbshipit-source-id: 52018e826826ae62beaad76becb5edf880be253f


Development

Successfully merging this pull request may close these issues.

pinv could be differentiable on a wider range of inputs
