
Make __new__ on tensor subclasses match _make_subclass#73727

Closed
ezyang wants to merge 1 commit into gh/ezyang/1090/base from gh/ezyang/1090/head

Conversation


@ezyang ezyang commented Mar 3, 2022

Stack from ghstack:

Previously, calling SubclassTensor(tensor) would give you a
SubclassTensor where the underlying at::Tensor was computed by an
alias() call. In particular, a grad_fn would be created in this
situation. This is usually not what people want, as the alias grad_fn
is oblivious to the subclass's semantics (and just as likely to be wrong)
and it means that you cannot use the constructor to directly create a
leaf SubclassTensor with requires_grad=True.

This PR changes the meaning of this call so that SubclassTensor(tensor)
is equivalent to torch.Tensor._make_subclass(SubclassTensor, tensor);
that is to say, the underlying at::Tensor is created by a detach() call
(deleting grad_fn), and furthermore the requires_grad defaults to False
(but you can set it explicitly afterwards). I keep exactly the old
behavior when you construct a plain Tensor, which could be somewhat
confusing since the two cases no longer match.
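To illustrate the _make_subclass behavior this PR matches (a sketch; SubclassTensor is an arbitrary example subclass, not a class defined by this PR):

```python
import torch

class SubclassTensor(torch.Tensor):
    pass

base = torch.randn(2, requires_grad=True)

# _make_subclass detaches under the hood: the result shares storage
# with `base`, but has no grad_fn, and requires_grad defaults to False.
sub = torch.Tensor._make_subclass(SubclassTensor, base)
assert isinstance(sub, SubclassTensor)
assert sub.grad_fn is None
assert not sub.requires_grad and sub.is_leaf

# requires_grad can be set explicitly afterwards, giving a leaf
# subclass tensor that requires grad.
sub.requires_grad_(True)
assert sub.is_leaf and sub.requires_grad
```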

I'm not sure if this is completely correct. Here are some other ways
we could skin the cat:

  • detach(), but propagate requires_grad-ness. This lets an idiom like
    TensorSubclass(torch.empty(2, requires_grad=True)) do the intuitive
    thing.
  • detach(), ignore input requires_grad and also accept a requires_grad
    kwarg for setting requires_grad directly. This means you would write
    TensorSubclass(torch.empty(2), requires_grad=True) to create a leaf
    node.
  • Same as above, but assert that requires_grad=False or that we are
    in a no_grad mode. This would remind users that if they want a
    non-leaf tensor subclass, they are obligated to think about the
    autograd semantics for this boundary.
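The second option could be sketched as follows (hypothetical; TensorSubclass and its __new__ signature are illustrative, not part of this PR):

```python
import torch

class TensorSubclass(torch.Tensor):
    # Hypothetical __new__ for the second option: detach the input,
    # ignore its requires_grad-ness, and accept an explicit
    # requires_grad kwarg (same pattern as torch.nn.Parameter).
    def __new__(cls, data, requires_grad=False):
        return torch.Tensor._make_subclass(cls, data, requires_grad)

# Creates a leaf node that requires grad, regardless of the input's
# requires_grad-ness.
leaf = TensorSubclass(torch.empty(2), requires_grad=True)
assert leaf.is_leaf and leaf.requires_grad and leaf.grad_fn is None
```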

Signed-off-by: Edward Z. Yang ezyang@fb.com

Differential Revision: D34615319


pytorch-bot bot commented Mar 3, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/df853a17fbb7babeb103a1e6192bc3bcd6ca7f41/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds.

31 workflows were triggered under the ciflow/default label (linux, macos, and windows binary and test builds); 24 workflows were skipped (caffe2, docker-builds, ios, libtorch, periodic/scheduled, macos, parallelnative, and xla jobs).


facebook-github-bot commented Mar 3, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit df853a1 (more details on the Dr. CI page):



🕵️ 10 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge) (1/10)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-03-03T05:37:32.7740876Z FAIL [0.001s]: test_base (__main__.TestTorchFunctionOverride)
2022-03-03T05:37:32.7643273Z   test_precedence_semantics (__main__.TestTorchFunctionOverride)
2022-03-03T05:37:32.7663311Z Test semantics for __torch_function__ for functions that take ... ok (0.004s)
2022-03-03T05:37:32.7675688Z   test_tensor_subclass_propagation (__main__.TestTorchFunctionOverride)
2022-03-03T05:37:32.7682627Z this test exercises the functionality described in ... ok (0.002s)
2022-03-03T05:37:32.7687109Z   test_user_implementation_raises (__main__.TestTorchFunctionOverride)
2022-03-03T05:37:32.7691495Z Test that errors raised in user implementations propagate correctly ... ok (0.001s)
2022-03-03T05:37:32.7727179Z   test_warn_on_invalid_torch_function (__main__.TestTorchFunctionWarning) ... ok (0.003s)
2022-03-03T05:37:32.7738021Z   test_wrap_torch_function (__main__.TestWrapTorchFunction) ... ok (0.001s)
2022-03-03T05:37:32.7738338Z 
2022-03-03T05:37:32.7740245Z ======================================================================
2022-03-03T05:37:32.7740876Z FAIL [0.001s]: test_base (__main__.TestTorchFunctionOverride)
2022-03-03T05:37:32.7741607Z ----------------------------------------------------------------------
2022-03-03T05:37:32.7742205Z Traceback (most recent call last):
2022-03-03T05:37:32.7742548Z   File "test_overrides.py", line 556, in test_base
2022-03-03T05:37:32.7742873Z     self.assertTrue(c._is_view())
2022-03-03T05:37:32.7743214Z AssertionError: False is not true
2022-03-03T05:37:32.7743424Z 
2022-03-03T05:37:32.7743766Z ----------------------------------------------------------------------
2022-03-03T05:37:32.7744164Z Ran 25 tests in 0.282s
2022-03-03T05:37:32.7744335Z 
2022-03-03T05:37:32.7744526Z FAILED (failures=1, expected failures=3)

The same failure (FAIL: test_base (__main__.TestTorchFunctionOverride), AssertionError on self.assertTrue(c._is_view()) at test_overrides.py line 556) was reported by the remaining nine builds:

- linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu) (2/10)
- linux-bionic-rocm4.5-py3.7 / test (default, 1, 2, linux.rocm.gpu) (3/10)
- linux-bionic-py3.7-clang9 / test (noarch, 1, 1, linux.2xlarge) (4/10)
- linux-xenial-py3.7-clang7-asan / test (default, 3, 3, linux.2xlarge) (5/10)
- win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu) (6/10)
- linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge) (7/10)
- win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge) (8/10)
- linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge) (9/10)
- win-vs2019-cuda11.3-py3 / test (force_on_cpu, 1, 1, windows.4xlarge) (10/10)

🚧 1 ongoing upstream failure:

These were probably caused by upstream breakages that are not fixed yet.


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


ezyang added a commit that referenced this pull request Mar 3, 2022
ghstack-source-id: df8e576
Pull Request resolved: #73727
@ezyang requested review from Chillee, bdhirsh and zou3519 March 3, 2022 04:49
Collaborator

@albanD left a comment


Looks good!

TORCH_CHECK(type != &THPVariableType, "Cannot directly construct _TensorBase; subclass it and then construct that");
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
- auto tensor = torch::utils::legacy_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs);
+ auto tensor = torch::utils::legacy_tensor_ctor(type, torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs);
Collaborator


Do I read correctly that this constructor does not accept the "requires_grad" flag at all today? And thus can't be used to do Foo(torch.empty(2), requires_grad=True)?
In that case, I think I like the idea of raising an error if the input requires_grad and we're not in no_grad mode. We can even recommend using a custom Function if they want to make this construction differentiable.
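That proposal can be sketched roughly like this (hypothetical `Subclass` and error message; this is not code from the PR, just an illustration of the suggested check):

```python
import torch

class Subclass(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, requires_grad=False):
        # Hypothetical check: refuse a differentiable input unless grad
        # recording is already off, nudging users toward a custom Function.
        if elem.requires_grad and torch.is_grad_enabled():
            raise RuntimeError(
                "input requires grad; construct inside a custom "
                "autograd.Function (or under torch.no_grad()) to define "
                "the autograd semantics of this boundary explicitly")
        return torch.Tensor._make_subclass(cls, elem, requires_grad)

s = Subclass(torch.zeros(2), requires_grad=True)  # ok: a leaf that requires grad
assert s.is_leaf and s.requires_grad

with torch.no_grad():
    t = Subclass(torch.zeros(2, requires_grad=True))  # ok: grad mode is off
assert not t.requires_grad
```

Passing a tensor that requires grad while grad mode is on would raise the error above.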

Contributor Author


Based on your chat comments, I was planning to re-add the requires_grad field; it seemed like a good idea.

Collaborator


So the final thing would be:

  • You have a requires_grad flag you can use to create a leaf that requires grad
  • The given Tensor can never require gradients.

Contributor Author


One thing that gives me pause about forcing the given Tensor to not require gradients is that it is a bit tiresome for AOTAutograd when you get a requires_grad input; we actually do want to turn these into leaves so that we can compute gradients only to them. But I guess it is not too bad; it looks like `Tracer(x.detach(), requires_grad=x.requires_grad)` works.

Comment on lines -108 to -110
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
Contributor


How does the new constructor work with wrapper Tensors? I guess I want to see an example of how this changes construction for the different types of tensor subclasses.

Someone might want to create a wrapper Tensor subclass (DiagTensor?) that holds a vector and would want DiagTensor(blah) to work. Does this mean DiagTensor would override __new__?

Contributor Author


WrapperTensor still has to define __new__; in general, any subclass whose metadata doesn't match the wrapped tensor's doesn't work here.

@ezyang
Contributor Author

ezyang commented Mar 3, 2022

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Contributor Author

ezyang commented Mar 4, 2022

Based on the failing test case, I need to revise my plan.

  • Today, users can define "metadata-only" tensor subclasses without explicitly overriding __new__. These tensor subclasses participate in autograd with their inputs in the normal way. We can't change this behavior without breaking BC.
  • Similarly to the above, without an explicit detach, I need view metadata to be preserved across the subclass creation boundary, because the subclass is indeed a view of the tensor you pass in, and we should accurately track this by default.

So here is what I suggest we do.

  • As before, __new__ calls alias() on the input and sets up a grad_fn if the input requires grad. If a user needs to override the autograd behavior, they just make sure the constructor is only called inside a custom function (where grad is disabled, so you end up with a requires_grad=False tensor).
  • Sometimes, you will want to turn the tensor subclass into a leaf on creation. For example, if you write Subclass(torch.zeros(2, requires_grad=True)), in many cases you don't want to backprop into the now-dead leaf. So there should be a convenient API for this.

@ezyang
Contributor Author

ezyang commented Mar 4, 2022

So we will have these modes of use:

# Treat Subclass as a view function with no computational content
r = Subclass(torch.zeros(2)) # r.grad_fn == None, r._is_view() == True
r = Subclass(torch.zeros(2, requires_grad=True)) # r.grad_fn == AliasBackward0, r._is_view() == True

# Treat Subclass as a view function with nontrivial computation content (custom autograd)
r = SubclassFunction.apply(torch.zeros(2, requires_grad=True)) # r.grad_fn == SubclassFunction, r._is_view() == True (but for now it will be false until https://github.com/pytorch/pytorch/issues/73604 is fixed)

# Create a new leaf variable that is disconnected from the previous graph (e.g., you want a parameter to be the subclass)
r = Subclass(torch.zeros(2), requires_grad=True)  # r._is_view() == False
r = Parameter(Subclass(torch.zeros(2))) # this works too, after https://github.com/pytorch/pytorch/issues/73604
# NB: Subclass(Parameter(torch.zeros(2))) doesn't work; we will interpret this as Subclass as a view function on the parameter!
# NB: for symmetry with Parameter, we will accept torch.zeros(2, requires_grad=True), but it is still detached
# NB: Subclass(torch.zeros(2), requires_grad=False)  # r._is_view() == False; aka `requires_grad` doesn't default to False; it defaults to None. For clarity, we could also consider supporting a separate kwarg `detach=True`. This is mostly niche, as if the torch.zeros(2) is not aliased from anywhere else, it doesn't really matter if we detach or not unless you post facto set `requires_grad=True`

@albanD
Collaborator

albanD commented Mar 4, 2022

r = Subclass(torch.zeros(2), requires_grad=True) # r._is_view() == False

Why is the view False here?
In my mind, the requires_grad-ness is independent of the differentiability (which propagates the view).

Or is the problem that views that are leaves are ~broken today?

@ezyang
Contributor Author

ezyang commented Mar 4, 2022

In my mind, the requires_grad-ness is independent of the differentiability (which propagates the view).

Under the hood, setting requires_grad causes a detach rather than an alias. Detach destroys view-ness info which is why it reports false. There is a little bit of indirectness here: I am assuming that if you said requires_grad, you want to define a leaf, and to define a leaf it cannot be a view.
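The view-ness bookkeeping described here can be checked directly with the private `_is_view()` and `_base` helpers (a small sketch; these are internal APIs and subject to change):

```python
import torch

base = torch.zeros(3)
v = base[1:]                 # slicing records view metadata
assert v._is_view()
assert v._base is base       # the view remembers its base

d = v.detach()               # detach drops the view metadata...
assert not d._is_view()      # ...so the result no longer reports as a view

assert not torch.zeros(3)._is_view()  # fresh tensors are not views
```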

@ezyang
Contributor Author

ezyang commented Mar 30, 2022

OK, so there is a problem with the strategy I suggested, which can be seen with this sample program using the subclass zoo:

        x = TrivialTensorViaInheritance(torch.tensor(1.0))
        y = x * torch.tensor(1.0, requires_grad=True)
        y.relu_()

failing with

RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

It shouldn't. The problem is that inside __torch_dispatch__, when we allocate a new TrivialTensorViaInheritance using the default strategy, this results in an alias() tensor with metadata set up for the view (even though we're inside a "no grad" context, e.g., after autograd). This gets propagated to the final return, and so y looks like a view, even though all that really happened was we wanted to take the original tensor and wrap it into a subclass (without actually doing any view tracking). The error message is slightly wrong here: y is not actually a leaf.
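For context, the error itself is the standard in-place-on-a-view-of-a-leaf check, which can be reproduced without any subclass:

```python
import torch

x = torch.zeros(2, requires_grad=True)   # a leaf that requires grad
y = x[0:1]                               # a genuine view of that leaf
try:
    y.relu_()                            # in-place op on the view
except RuntimeError as e:
    # "a view of a leaf Variable that requires grad is being used
    #  in an in-place operation."
    print(e)
```

The subclass case above trips the same check, even though no real view relationship was intended.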

@ezyang
Contributor Author

ezyang commented Mar 31, 2022

My current thinking is that when we are "past" the autograd layer (as is the case with __torch_dispatch__), all view tracking should be disabled. We don't currently have a TLS way of doing that, but it wouldn't be too hard to add. View information should then be restored by the AutogradViewOrInplace key; in particular, if you return a view, this is not an error, but it won't be tracked unless AutogradViewOrInplace sets it up.

Something that doesn't work is to forbid torch dispatch from returning views; we can see from the default constructor for Tensor that the most obvious way of creating a tensor subclass involves creating a view from a temporary tensor that immediately ceases to exist (so you shouldn't think of it as a view).

One thing that is annoying is if someone incorrectly returns a real view (as opposed to a fake view derived from a temporary tensor) we will lose the view information. Not sure if there is an easy way to detect this has occurred.

@albanD
Collaborator

albanD commented Mar 31, 2022

My current thinking is that when we are "past" the autograd layer (as is the case with __torch_dispatch__), all view tracking should be disabled.

It is today. It only gets restored alongside the Autograd key by PythonTLSSnapshot, which is what you want, as autograd is re-enabled there.

nit: for "base_tensor.py": the __init__() method should have the same signature as __new__, otherwise the extra args cannot be used.

Changing the "wrap" of the TrivialTensorViaInheritance to do return cls(t, requires_grad=False) does solve the problem. Is that a satisfying solution for you?

@ezyang
Contributor Author

ezyang commented Mar 31, 2022

It is today. It only gets restored alongside the Autograd key by PythonTLSSnapshot, which is what you want, as autograd is re-enabled there.

Hmm, this doesn't seem like a complete explanation to me. The problem here is that there are two "levels" of autograd (one for the outer object which I got as inputs, and one for the inner objects I may be wrapping over), and I only want autograd to be reenabled for the inner objects.

With the current behavior, weird stuff like this can happen:

    def test_grad_fn(self):
        class TestTensor(BaseTensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func is torch.ops.aten.add.Tensor and 'alpha' in kwargs:
                    # decompose it
                    r = torch.add(args[0], args[1] * kwargs['alpha'])
                    self.assertIsNone(r.grad_fn)
                    return r
                return super().__torch_dispatch__(func, types, args, kwargs)

        x = TestTensor(torch.tensor(1.0)).requires_grad_()
        y = TestTensor(torch.tensor(2.0)).requires_grad_()
        torch.add(x, y, alpha=2)

This test fails, which seems suboptimal.

@ezyang
Contributor Author

ezyang commented Apr 1, 2022

So my current thinking is that a user should be explicitly responsible for restoring TLS. A few justifications for this:

  1. This is how it works for custom autograd function forwards; we don't detach the inputs, instead autograd is disabled and you have to reenable it if you want to do fancy "reentrant" autograd stuff.
  2. We don't want to detach inputs before passing them into torch dispatch because this will result in a ton of unnecessary allocation, a really big performance footgun.
  3. There's no non-context-manager API for users to explicitly say whether they want to operate in the "mode not restored" or "mode restored" context; we only provide torch functions (but see also the API proposal Make it possible to skip only one hop of __torch_function__ override #55093; providing this universally is a huge lift).
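Point 1 can be seen with a plain custom Function: grad recording is off inside forward() (the inputs are not detached), and the author opts back in with torch.enable_grad() if they want reentrant autograd. A minimal sketch:

```python
import torch

class Passthrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Grad recording is disabled inside forward; inputs keep their
        # requires_grad flags, but ops here produce no grad_fn. Use
        # `with torch.enable_grad():` to opt back in explicitly.
        assert not torch.is_grad_enabled()
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

x = torch.zeros(2, requires_grad=True)
y = Passthrough.apply(x)
assert y.requires_grad and y.grad_fn is not None  # tracked by the Function
y.sum().backward()
assert torch.equal(x.grad, torch.ones(2))
```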

@albanD
Collaborator

albanD commented Apr 1, 2022

From offline discussion, we agreed on updating the TLS system: #75130

I think we want to revisit this issue in that context afterwards.

@eellison
Contributor

Is this PR still necessary?

@ezyang ezyang closed this May 20, 2022
@ezyang
Contributor Author

ezyang commented May 20, 2022

No, it's wrong, so we won't do it.
