
Introduce discontinuity to nested tensor#80981

Closed
YifanShenSZ wants to merge 24 commits intopytorch:masterfrom
YifanShenSZ:master

Conversation


@YifanShenSZ YifanShenSZ commented Jul 6, 2022

Nested tensors used to assume that the buffer memory is contiguous. However, some operations can break that assumption:

  • reshape
  • transpose
  • slice

To be able to access the underlying tensors of a discontinuous buffer, we need three pieces of metadata:

  • sizes of each tensor (nested_size_tensor_)
  • strides of each tensor (nested_stride_tensor_)
  • offset of each tensor (offsets_)

so we access each tensor via buffer.as_strided(size, stride, offset).

This pull request introduces the offsets metadata, then adds reshape and transpose so that we can create discontinuous cases for testing. Unbind, select, dropout, softmax, and bmm are refactored to provide tests.
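The access pattern described above can be sketched with plain tensors. This is a hedged illustration, not the internal NestedTensorImpl API: the buffer and metadata variables below are made up to mirror the `nested_size_tensor_` / `nested_stride_tensor_` / `offsets_` triple.

```python
import torch

# Store two tensors of shapes (2, 3) and (4, 3) in one flat buffer, then
# recover each component with as_strided(size, stride, offset).
a = torch.arange(6.).reshape(2, 3)
b = torch.arange(12.).reshape(4, 3)
buffer = torch.cat([a.reshape(-1), b.reshape(-1)])  # flat contiguous buffer

sizes = [(2, 3), (4, 3)]
strides = [(3, 1), (3, 1)]  # row-major strides for each component
offsets = [0, 6]            # a occupies buffer[0:6], b occupies buffer[6:18]

components = [
    buffer.as_strided(size, stride, offset)
    for size, stride, offset in zip(sizes, strides, offsets)
]
assert torch.equal(components[0], a)
assert torch.equal(components[1], b)
```

Once an op like transpose rewrites the strides, or slice shifts the offsets, the buffer is no longer contiguous per component, but the same `as_strided` access still works.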


facebook-github-bot commented Jul 6, 2022


✅ No Failures (0 Pending)

As of commit b765b86 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@YifanShenSZ YifanShenSZ marked this pull request as ready for review July 6, 2022 19:59
@YifanShenSZ YifanShenSZ marked this pull request as draft July 6, 2022 20:36
… bmm with offsets and strides for discontinuous buffer memory
@YifanShenSZ YifanShenSZ marked this pull request as ready for review July 7, 2022 17:49
@YifanShenSZ YifanShenSZ requested a review from bdhirsh as a code owner July 7, 2022 17:49
@YifanShenSZ YifanShenSZ changed the title Add nested tensor metadata offsets Introduce discontinuity to nested tensor Jul 7, 2022

@jbschlosser jbschlosser left a comment


Thanks for the update! There is quite a bit in this PR, but it wasn't as bad to review as I thought. For most of the operations, the refactors make perfect sense.

I think the reshape impl is a bit unwieldy atm, so I suggested some slight refactoring below. IIRC it should be relatively simple to make the logic addressing the various cases clear. Additionally, I think the tests would benefit from increased clarity as to what they're testing.

Overall, this looks like quite a bit of work, but I believe it's correct!

)

@torch.inference_mode()
def test_unbind_discontinuous(self):
Contributor

Hm, I'm finding it a bit hard to understand what's going on in these tests.

testing suggestion: throughout these tests, call into a utility that produces a non-contiguous tensor from a given tensor. You can then run the contiguous and non-contiguous inputs through the op under test and verify output equivalence. wdyt about this?

Contributor Author

My idea here is to compare the nested tensor op with the padded op:

  • create a nested tensor, mess up its contiguity, then unbind
  • pad the nested tensor, apply the same reshaping and transposing, then (kind of) unbind

Correct me if I'm wrong: are you suggesting comparing 2 nested tensors that appear to have the same entries but in fact have different memory layouts?

Contributor

Correct me if I'm wrong: are you suggesting comparing 2 nested tensors that appear to have the same entries but in fact have different memory layouts?

Right, we'd expect that if we run the same operation on each of these, their outputs would also appear to have the same entries (and possibly different memory layout, but this depends on the op). Existing tests cover nested contiguous vs. padded, and here the focus would be on nested contiguous vs. nested non-contiguous. wdyt, are we losing any coverage this way?
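The testing pattern suggested here can be sketched with plain (non-nested) tensors. `noncontiguous_like` is a hypothetical helper written for this illustration, not an existing PyTorch utility:

```python
import torch

def noncontiguous_like(t):
    # Return a tensor with the same entries as `t` but a non-contiguous
    # memory layout: transpose, materialize, transpose back.
    assert t.dim() == 2
    return t.t().contiguous().t()

x = torch.randn(3, 5)
x_nc = noncontiguous_like(x)
assert torch.equal(x, x_nc) and not x_nc.is_contiguous()

# Run the op under test on both layouts and verify output equivalence.
out_c = torch.softmax(x, dim=-1)
out_nc = torch.softmax(x_nc, dim=-1)
assert torch.allclose(out_c, out_nc)
```

The same shape of test applies to nested tensors: produce a non-contiguous nested tensor from a contiguous one, run the op on both, and compare entries.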

@YifanShenSZ YifanShenSZ marked this pull request as draft July 8, 2022 19:58
@YifanShenSZ
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/master pull/80981/head returned non-zero exit code 1

Rebasing (1/10)
Auto-merging aten/src/ATen/native/native_functions.yaml
Auto-merging aten/src/ATen/native/nested/NestedTensorMath.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/native/nested/NestedTensorMath.cpp
Auto-merging test/test_nestedtensor.py
CONFLICT (content): Merge conflict in test/test_nestedtensor.py
error: could not apply f41be74b40... support nested_tensor * scalar
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply f41be74b40... support nested_tensor * scalar

Raised by https://github.com/pytorch/pytorch/actions/runs/2651942641

@YifanShenSZ YifanShenSZ marked this pull request as ready for review July 12, 2022 18:04
@facebook-github-bot
Contributor

@YifanShenSZ has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

device_guard: False
dispatch:
CompositeImplicitAutograd: reshape
NestedTensorCPU, NestedTensorCUDA: reshape_nested
Collaborator

I kinda doubt this is related to the inference_mode() issue, but this looks a bit suspicious. Do we want NestedTensor::reshape to be able to work with autograd? If we do, some quick comments:

The way that normal reshape() works today is that it sometimes returns a view of the input and sometimes doesn't, depending on the arguments to reshape. We can't actually support an autograd formula for that directly, so what happens instead is that reshape() decomposes into a view op or a clone, and autograd handles those individual ops instead. Does the nested tensor implementation of reshape() also "sometimes return a view"?
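The view-or-copy behavior of regular reshape() can be observed directly (a small illustration with dense tensors, not nested ones):

```python
import torch

x = torch.arange(6)

# Contiguous input, compatible shape: reshape returns a view (shares storage).
v = x.reshape(2, 3)
assert v.data_ptr() == x.data_ptr()

# Non-contiguous input whose strides can't express the new shape: reshape copies.
y = torch.arange(6).reshape(2, 3).t()  # transposed, non-contiguous
c = y.reshape(6)
assert c.data_ptr() != y.data_ptr()
```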

Contributor

@bdhirsh Yes that's correct, AFAICT the NT implementation of reshape uses the same underlying logic to determine when to return a view and when to return a copy. Do you see it being possible to define a NT-specific reshape_backward that deals with this properly?

Note that we need a NT-specific version of reshape because the semantics of -1s are slightly different due to the possibility for ragged dims.

Contributor Author

The nested tensor implementation can return a view, but I guess that's not the "view" the autograd engine uses:

  • I heard that autograd handles "view" differently, as a "fast path"
  • I also heard that this "fast path" is disabled for nested tensors: autograd records a nested tensor view just as another function
  • So based on sharing memory it can return a view, but based on the "fast path" it's not an autograd view, I think

Could that trigger some edge case?

Contributor Author

@bdhirsh Yes that's correct, AFAICT the NT implementation of reshape uses the same underlying logic to determine when to return a view and when to return a copy. Do you see it being possible to define a NT-specific reshape_backward that deals with this properly?

Note that we need a NT-specific version of reshape because the semantics of -1s are slightly different due to the possibility for ragged dims.

Yes, the way I implemented reshape_backward is just like Driss's approach for linear. Although I could also do the ugly "is_nested()" version to maintain CompositeImplicit-ness.

Contributor

Yes, the way I implemented reshape_backward is just like Driss's approach for linear.

This is good, but I'm curious how this backward deals with the forward sometimes returning a view and sometimes returning a copy?

Contributor Author

Yes, the way I implemented reshape_backward is just like Driss's approach for linear.

This is good, but I'm curious how this backward deals with the forward sometimes returning a view and sometimes returning a copy?

Basically no special treatment: simply reshape grad back to the input's shape.
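As a rough sketch of that backward (illustrative only; `reshape_backward` here is a made-up free function, not the actual derivative registration):

```python
import torch

def reshape_backward(grad_output, input_shape):
    # No special treatment: reshape the incoming gradient to the input's shape.
    return grad_output.reshape(input_shape)

g = torch.randn(6)
assert reshape_backward(g, (2, 3)).shape == (2, 3)
```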

@YifanShenSZ
Contributor Author

YifanShenSZ commented Jul 29, 2022

A minimal reproduction of an error case not covered by CI:

import torch

@torch.inference_mode()
def unary(x):
    y = x.reshape(2, -1, 2, 3)
    return y

# with torch.inference_mode():
if __name__ == "__main__":
    x = torch.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))])
    y = unary(x)
    print(y)

would raise "RuntimeError: Cannot set version_counter for inference tensor". If we instead release with torch.inference_mode(): from the comment (i.e. use the context manager around the calls rather than the decorator), it runs properly.

The issue was found to be reshape-specific, i.e. other ops such as dropout and softmax are fine. It probably comes from the special autograd logic for reshape: reshape has to be kept CompositeImplicit, without adding any other backend key.

Successfully fixed this issue by keeping reshape CompositeImplicit. Commits are coming...

@YifanShenSZ
Contributor Author

Thanks @drisspg for trying to reproduce the exact error. Quoting his comment below:

1. Run the repro script on fbcode/stable -> runs successfully
2. Run the script on just my changes -> runs successfully
3. Run the script on Yifan's reshape changes (making it explicitly composite implicit) plus my changes -> script fails
4. Apply the patch to reshape to make it not explicitly composite implicit, on top of mine and Yifan's changes -> runs successfully

@YifanShenSZ YifanShenSZ marked this pull request as ready for review July 29, 2022 23:57
@YifanShenSZ
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

facebook-github-bot pushed a commit that referenced this pull request Aug 1, 2022
Summary:
Nested tensors used to assume that the buffer memory is contiguous. However, some operations can break that assumption:
* reshape
* transpose
* slice

To be able to access underlying tensors from discontinuous buffer, we need 3 metadata:
* sizes of each tensor (`nested_size_tensor_`)
* strides of each tensor (`nested_stride_tensor_`)
* offset of each tensor (`offsets_`)

so we access each tensor by `buffer.as_strided(size, stride, offset)`

This pull request introduces the offsets metadata, then adds reshape and transpose so that we can create discontinuous cases for testing. Unbind, select, dropout, softmax, and bmm are refactored to provide tests.

Pull Request resolved: #80981
Approved by: https://github.com/jbschlosser

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5f9939f65eea8b5ea017fdd10668f48364d6c0b1

Reviewed By: osalpekar

Differential Revision: D38306864

Pulled By: YifanShenSZ

fbshipit-source-id: 7870506ff619f5697e1add2be0b9200265e9c320

@albanD albanD left a comment


@YifanShenSZ this implementation is silently wrong. I think you will want to do a major rework of this PR. Maybe it would even be safer to revert it? What do you think?

device_check: NoCheck
device_guard: False

- func: _reshape_nested(Tensor self, int[] shape) -> Tensor
Collaborator

Why is this schema different from the reshape above? This function also sometimes returns a view and sometimes doesn't.

So this schema is wrong (and autograd is going to generate silently wrong code for it!)

It is NOT possible to register autograd for a function that has such a property, so you will need to keep reshape CompositeImplicitAutograd. Since the computeStride function there will not work with nested tensors, you can simply create a new native function to abstract it away: make it a native function that takes self as input directly, and you will be able to provide a custom Nested implementation for it.
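A hedged Python sketch of the decomposition being suggested: the composite reshape picks between a view op and a copy, and autograd differentiates those primitives instead of reshape itself. The contiguity test below is a simplified stand-in for the real computeStride check.

```python
import torch

def composite_reshape(x, shape):
    # Simplified stand-in for computeStride: contiguous tensors can always
    # be viewed with the requested shape; otherwise, copy first.
    if x.is_contiguous():
        return x.view(shape)           # autograd records a view op
    return x.contiguous().view(shape)  # autograd records a copy, then a view

x = torch.randn(2, 3, requires_grad=True)
out = composite_reshape(x.t(), (6,))   # non-contiguous input takes the copy path
out.sum().backward()                   # autograd differentiates the primitives
assert x.grad.shape == x.shape
```

Because each branch is an ordinary differentiable primitive, no bespoke "sometimes view" backward formula is needed.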

Collaborator

btw this wrong schema is also most likely the reason for the inference-mode issue you had, as the autograd kernel is not properly registered!

// See Note [NestedTensor Not Included in Backend Keys]
// The caveat to that note is that nested_tensor is a special case
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
Collaborator

I'm not sure I understand. Why can't you just override the explicit kernels you don't want?

Contributor Author

This is Driss's work, which we thought would be a necessary fix here. The original discussion is here.

@drisspg @bdhirsh would know better than I do

Contributor

This was added in order to allow composite implicit kernels to work with nested tensors. Some functions like softmax are composite implicit functions that in turn call non-composite ops like _softmax. In order to be able to register nested tensor kernels at _softmax, nested tensors need to be able to run through the composite function. Similarly, reshape is an example of one of these ops whose top-most function is composite implicit and quickly calls a non-composite op.

Collaborator

I understand that you want it to get the Implicit kernels, but this comment seems to indicate that you use special logic (compared to the other classical backend keys) to prevent it from getting the Explicit kernels. Did I read this comment wrong?

Contributor

Yeah, that is correct: work for implicit, don't work for explicit. I went back and forth about this; practically, it appears that there are a lot more composite explicit kernels that end up calling sizes and producing strange, hard-to-debug errors for users. This possibility exists for composite implicit as well, but it was happening with less frequency for those kernels.

Is your suggestion to enable both composite kernel types and, whenever we find a composite explicit kernel doing something wrong, just register a TORCH_CHECK(false) kernel?

Collaborator

Is your suggestion to enable both composite kernel types and, whenever we find a composite explicit kernel doing something wrong, just register a TORCH_CHECK(false) kernel?

That would have been my first guess, yes. But I am not sure:

  • how involved is the current "workaround" to get Implicit but not Explicit kernels? cc @bdhirsh
  • What happens when you want to use some of the Explicit kernels? Do you now have to copy/paste them?

// everything is fine
// * but if we create the input tensor not with inference mode,
// then errors like "Cannot set version_counter for inference tensor" arise
if (self.is_nested()) {
Collaborator

Doing such a check is usually a code smell. You use the dispatcher to call reshape, so if you had the right dispatch key for it, you would have ended up in your implementation directly.

@YifanShenSZ
Copy link
Contributor Author

This "fix" works fine currently and passes all the tests, but in order to support autograd, a follow-up PR is coming. Thanks for the advice from @albanD.

drisspg referenced this pull request in drisspg/pytorch Aug 31, 2022
Summary:
Pull Request resolved: pytorch#84154

Previous reshape [https://github.com/pytorch/pytorch/issues/80981](https://github.com/pytorch/pytorch/pull/80981) is ok for forward, but needs improvement for backward: need to handle "sometimes view sometimes copy" behavior.

This pull request fixes it by:
1. add a new alias dispatch key `CompositeImplicitAutogradNestedTensor`, which ideally would work as nested-tensor version of `CompositeImplicitAutograd`
2. register `reshape_nested` to `reshape` by `CompositeImplicitAutogradNestedTensor`

Side changes:
* add contiguous memory format support to `clone_nested`
* add `view_nested`
* add `reshape_as_nested`

Fix issue [https://github.com/pytorch/pytorch/issues/83041](https://github.com/pytorch/pytorch/issues/83041)

Pull Request resolved: pytorch#82754

Test Plan:
Imported from GitHub, without a `Test Plan:` line.


Reviewed By: albanD, bdhirsh

Differential Revision: D39023822

Pulled By: drisspg

fbshipit-source-id: 87acd4f9fb61cd094fccad7801c25e2a1bfed88b
pytorchmergebot referenced this pull request Sep 1, 2022
The original author is @YifanShenSZ  and the original PR is: #82754
(Summary duplicates the #84154 commit message above.)
facebook-github-bot referenced this pull request Sep 1, 2022
(Summary duplicates the #84154 commit message above.)

Labels

cla signed · Merged · release notes: nested tensor (changes that have a direct impact on nested tensors) · Reverted · topic: not user facing


8 participants