functionalization: skip meta reference compute for aot autograd #87108

Closed

bdhirsh wants to merge 4 commits into gh/bdhirsh/328/base from gh/bdhirsh/328/head

Conversation

@bdhirsh (Collaborator) commented Oct 17, 2022

The context is that historically, XLA/LTC tensors haven't had accurate stride information, and functionalization would run "reference" meta kernels for view ops on the side to properly compute strides.

This is more complicated in the symint tracing world: we have a `FunctionalTensorWrapper` that wraps the underlying tensor and has its own set of sizes/strides metadata, but we never create proxy objects for the sizes/strides of the wrapper.

In the symint tracing world with AOT Autograd, we're guaranteed that our underlying strides are accurate anyway, since AOT Autograd uses fake tensors to perform tracing. We encountered a few bugs with SymInts from the `FunctionalTensorWrapper` making their way into `__torch_dispatch__`. To side-step that class of bugs completely (and marginally improve perf), this PR disables the meta tensor tracing for non-XLA/LTC use cases.

Stack from ghstack (oldest at bottom):
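A rough Python sketch (not PyTorch's actual implementation) of the two ideas in the description: a wrapper object carrying its own sizes/strides metadata on the side, separate from the thing it wraps, and the kind of "reference" stride computation a meta kernel performs for a view op, using transpose as the example. All names here are illustrative.

```python
# Illustrative sketch only -- these are not PyTorch's real classes or kernels.

def contiguous_strides(sizes):
    """Row-major strides for the given sizes, e.g. (2, 3, 4) -> (12, 4, 1)."""
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return tuple(strides)

def transpose_meta(sizes, strides, dim0, dim1):
    """Reference metadata for transpose: swap the two dims' sizes and strides,
    without touching any tensor data."""
    sizes, strides = list(sizes), list(strides)
    sizes[dim0], sizes[dim1] = sizes[dim1], sizes[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return tuple(sizes), tuple(strides)

class FunctionalWrapper:
    """Hypothetical stand-in for a wrapper tensor: it records its own
    sizes/strides metadata separately from the wrapped object's metadata."""
    def __init__(self, inner_sizes):
        self.inner_sizes = tuple(inner_sizes)   # metadata of the wrapped tensor
        self.sizes = tuple(inner_sizes)         # the wrapper's own copy
        self.strides = contiguous_strides(inner_sizes)

sizes, strides = transpose_meta((2, 3, 4), contiguous_strides((2, 3, 4)), 0, 2)
print(sizes, strides)  # (4, 3, 2) (1, 4, 12)
```

The point of the PR is that when the underlying tensors are fake tensors with accurate strides (the AOT Autograd case), running this side computation is redundant, so it can be skipped for non-XLA/LTC backends.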

@pytorch-bot bot commented Oct 17, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87108

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit f24231a:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 17, 2022
bdhirsh added a commit that referenced this pull request Oct 17, 2022
@albanD albanD removed their request for review October 17, 2022 19:21
@bdhirsh (Collaborator, Author) commented Oct 17, 2022

@ezyang I'm seeing some interesting UBSAN errors that seem to have been uncovered around SymIntNodeImpl (unrelated to functionalization). I've been staring at them for a while, but I'm wondering if you have any insight (logs).

The error is:

/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:267:45: runtime error: upcast of misaligned address 0x1ebebebebebebebe for type 'c10::SymIntNodeImpl', which requires 8 byte alignment
0x1ebebebebebebebe: note: pointer points here
<memory cannot be printed>
    #0 0x7fa87c57392d in c10::intrusive_ptr<c10::SymIntNodeImpl, c10::detail::intrusive_target_default_null_type<c10::SymIntNodeImpl> >::retain_() (/opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so+0x12292d)
    #1 0x7fa87c569ca7 in c10::SymInt::toSymIntNodeImpl() const (/opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so+0x118ca7)
    #2 0x7fa889ed1b90 in c10::SymInt::SymInt(c10::SymInt const&) (/opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so+0xd71cb90)

Some googling around led me here, and I was wondering if something funky is going on around our subclassing of intrusive_ptr, but I'm not too sure yet.

@ezyang (Contributor) commented Oct 18, 2022

The bebebe pattern in the address suggests that you're reading out of garbage memory that was memset to 0xBE, but I don't actually see anywhere in our codebase where we use the 0xBE bit pattern.
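A quick sketch of why an 0xBE fill pattern surfaces as an address like the one in the UBSAN log: if pointer-sized memory is filled with the byte 0xBE and then reinterpreted as a little-endian 64-bit pointer, the resulting "address" is the repeating pattern, which is not 8-byte aligned. (The logged address, 0x1ebebebebebebebe, shows the same repeating pattern with one byte differing.)

```python
# Simulate reading a pointer out of memory that was memset to 0xBE.
garbage = b"\xbe" * 8                      # like memset(buf, 0xBE, sizeof(void*))
addr = int.from_bytes(garbage, "little")   # reinterpret the bytes as a pointer
print(hex(addr))                           # 0xbebebebebebebebe
# 0x...be ends in binary 110, so the "pointer" violates 8-byte alignment,
# which is exactly what UBSAN's misaligned-upcast check reports.
assert addr % 8 != 0
```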

Are you sure this isn't functionalization related? I feel that the aotdispatch tests will be calling functionalization?

You should put your heads together with @anjali411 and @albanD, who were investigating another memory error; I wonder if this is the same root cause.

@bdhirsh (Collaborator, Author) commented Oct 18, 2022

I don't think the ASAN failure is actually related to this change (the original hypothesis from Horace's previous PR was that this PR might fix the skipped test), so I'm going to leave the test skipped and land this PR in the interest of fixing a few dynamic shapes models.

@wconstab (Contributor) left a comment

thanks for fixing!

@bdhirsh (Collaborator, Author) commented Oct 18, 2022

> Are you sure this isn't functionalization related? I feel that the aotdispatch tests will be calling functionalization?

Oh the segfault shows up in test_proxy_tensor.py (Horace added a skip for a segfaulting batch_norm test), which is why I think it's unrelated to functionalization.

@bdhirsh (Collaborator, Author) commented Oct 19, 2022

@pytorchbot merge -f "unrelated failure"

@pytorchmergebot commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions bot commented

Hey @bdhirsh.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.


Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged


4 participants