
Read out real strides from compilation result, rather than real args#105010

Closed
ezyang wants to merge 3 commits into gh/ezyang/2222/base from gh/ezyang/2222/head

Conversation

Contributor

@ezyang ezyang commented Jul 11, 2023

Stack from ghstack (oldest at bottom):

This prefigures a refactor that will move the backward compilation
to entirely ahead of time, so I need to extract these strides some
other way. Straight from the compiler's mouth will do it.

I can't easily get the information via the return result of fw_compiler without changing the calling convention, so instead I smuggle it via TracingContext. TracingContext may be None when we are compiling patterns for the joint graph pattern matcher.

Signed-off-by: Edward Z. Yang ezyang@meta.com

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8


pytorch-bot bot commented Jul 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105010

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e2a8ecd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

# compiler to aot_autograd
# Per output, what the compiler specified stride of the output is,
# or None if no stride is known
self.output_strides: Optional[List[Optional[List[int]]]] = None
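To make the side channel concrete, here is a minimal, hypothetical sketch of the pattern the snippet above describes: the compiler records one stride list (or None) per graph output on a tracing-context object, and the caller reads it back after compilation. `ToyTracingContext` and `fake_compiler` are illustrative stand-ins, not the real `torch._guards.TracingContext` API.

```python
from typing import List, Optional


class ToyTracingContext:
    def __init__(self) -> None:
        # Per output: the compiler-chosen strides, or None if no stride is known.
        self.output_strides: Optional[List[Optional[List[int]]]] = None


def fake_compiler(ctx: ToyTracingContext, outputs) -> None:
    """Pretend compiler: records strides for tensor-like outputs only."""
    ctx.output_strides = []
    for out in outputs:
        # Non-tensor outputs (e.g. a SymInt) carry no strides.
        strides = list(out["stride"]) if isinstance(out, dict) else None
        ctx.output_strides.append(strides)


ctx = ToyTracingContext()
fake_compiler(ctx, [{"stride": (4, 1)}, 7])
print(ctx.output_strides)  # [[4, 1], None]
```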
Contributor

Right now we do layout optimization only if dynamic shape is disabled. But we are working on resolving that (blocked on the split reduction support). So, should we set the type here to consider SymInt?

Contributor Author

yeah this can have symints, will amend

# Return the output strides to the caller via TracingContext
assert len(context.output_strides) == 0
for out in graph.graph_outputs:
if hasattr(out, "layout"):
Contributor

Do we see cases where out does not have a layout attribute? Does this happen when a SymInt is returned?

Contributor Author

Yes, we can have a SymInt return. Maybe there is something better than hasattr to do here, but a lot of inductor code is written this way. Hard to be more precise without typing 👹

Contributor Author

OK, turns out I can't, because strides in inductor are sympy.Symbol, not SymInt. So I am just going ahead and storing the hints here only. When you fix this to do permutations instead, probably can make this a little less fragile.

Contributor

Can we not do permutations instead because of some non-'dense' activations? This happens for some real models.

Contributor Author

There are a few ways we can do this. If we only ever change layout on dense tensors, we can make output_strides be None unless we changed the stride. Then the permutation is always defined.
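The "permutation is always defined" idea can be sketched like this: for a dense tensor, the dimension order is recoverable from the strides alone by sorting dimensions by stride, descending. This is an illustrative sketch, not the code under review.

```python
def stride_order(strides):
    """Indices of dims from outermost (largest stride) to innermost."""
    return sorted(range(len(strides)), key=lambda i: strides[i], reverse=True)


# Contiguous (2, 3, 4) tensor: strides (12, 4, 1) -> identity order.
print(stride_order([12, 4, 1]))  # [0, 1, 2]
# Channels-last-style strides (12, 1, 3) -> dim 1 moved innermost.
print(stride_order([12, 1, 3]))  # [0, 2, 1]
```

The sketch only works when strides are distinct and the tensor is dense, which is exactly the restriction discussed above.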

Contributor

If we only ever change layout on dense tensors, ...

I think this cannot be guaranteed in inductor right now. A non-dense tensor's layout may get changed because of layout changes to its upstream tensors, and inductor does not force the eager stride for that non-dense tensor because of the code here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/graph.py#L688 . I guess the algorithm to restride a non-dense tensor will be tricky.

Contributor Author

OK will need to think this through carefully. I'll start a doc.

Contributor Author

Upon further reflection, all of this is moot if we have compiled backwards, so let's not touch it unless it's causing someone problems.

if tc is None:
yield None
return
old_output_strides = tc.output_strides
Contributor

Just want to raise a concern about nested compilation, where we may need a stack for output_strides. But I cannot come up with a realistic example, so probably the current implementation is good enough.

Contributor Author

This implicitly is a stack via old_output_strides!
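The save-and-restore pattern in the snippet above is what makes nesting safe: each frame of the call stack holds one saved value, so the values behave like a stack without any explicit stack data structure. A minimal sketch of that pattern, with illustrative names:

```python
from contextlib import contextmanager


class TC:
    """Toy stand-in for a tracing context."""
    output_strides = None


@contextmanager
def fresh_output_strides(tc):
    old = tc.output_strides      # "push": saved in this stack frame
    tc.output_strides = []
    try:
        yield tc.output_strides
    finally:
        tc.output_strides = old  # "pop": restore on exit


tc = TC()
with fresh_output_strides(tc) as outer:
    outer.append((4, 1))
    with fresh_output_strides(tc) as inner:
        inner.append((1,))       # inner compile sees a fresh list
    print(tc.output_strides)     # [(4, 1)] -- outer value restored
print(tc.output_strides)         # None
```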

Contributor

ah, right, that's the whole point of the context manager after all... haha

continue

# Comparing ph_arg.stride() with real_arg.stride() directly may
if forward_saved_for_backwards_strides is None:
Contributor

Moving this check out of the for loop may be slightly faster.

Contributor Author

I was trying to save myself an indent 😂 This is only during compilation so I don't think it matters much

if real_stride is None:
continue

assert _get_hints(real_stride) == all_args[i].stride(), f"{real_stride} {all_args[i].stride()}"
Contributor

@shunting314 shunting314 commented Jul 11, 2023

I assume you will remove this soon since the main point of the change is we won't have access to all_args?

Contributor Author

Yup, this is just for CI here.

Comment on lines +3009 to +3011
forward_saved_for_backwards_strides = fwd_output_strides[
CompiledFunction.metadata.tensors_saved_for_backwards_slice
]
Contributor

This looks nice!
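The snippet under review narrows the full list of forward output strides down to just the tensors saved for backward, using a stored Python slice object. A small illustrative sketch with made-up values (the real slice comes from CompiledFunction.metadata):

```python
# Per-output strides as recorded by the compiler; None for non-tensor outputs.
fwd_output_strides = [[4, 1], None, [1], [6, 2, 1]]

# Suppose the last two forward outputs are the tensors saved for backward.
tensors_saved_for_backwards_slice = slice(2, 4)

forward_saved_for_backwards_strides = fwd_output_strides[
    tensors_saved_for_backwards_slice
]
print(forward_saved_for_backwards_strides)  # [[1], [6, 2, 1]]
```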

ezyang added a commit that referenced this pull request Jul 12, 2023
ghstack-source-id: f9ceb24
Pull Request resolved: #105010
Contributor Author

ezyang commented Jul 12, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@facebook-github-bot facebook-github-bot deleted the gh/ezyang/2222/head branch July 15, 2023 14:16
pytorchmergebot pushed a commit that referenced this pull request Jul 29, 2023
…05251)

Currently all information about the dependencies of ghstack PRs (e.g. #105010) is stripped away:
https://github.com/pytorch/pytorch/blob/c984885809194e0a807b3f5543450fae4dfa841a/.github/scripts/trymerge.py#L1077-L1078

This PR adds this information back in a more compact form. All dependencies (PR numbers) of each PR in ghstack are recorded.

The resulting commit message will look like this (the last line is new):

> Mock title (#123)
>
> Mock body text
> Pull Request resolved: #123
> Approved by: https://github.com/Approver1, https://github.com/Approver2
> ghstack dependencies: #1, #2

---

### Testing

Unit tests.

---

### Note Re: `# type: ignore[assignment]` in unit tests.

I did my due diligence to find alternatives. Unfortunately mypy [doesn't](python/mypy#6713) support this [way of patching methods](https://docs.python.org/3/library/unittest.mock-examples.html#mock-patching-methods), and the alternatives are either extremely verbose or don't work for this case. I decided it's not worth the effort (since the problem is limited only to the unit test).
Pull Request resolved: #105251
Approved by: https://github.com/huydhn
