Immediately compile backwards graph in AOTAutograd if dynamic shapes#104971

Closed
ezyang wants to merge 8 commits into gh/ezyang/2220/base from gh/ezyang/2220/head

Contributor

ezyang commented Jul 11, 2023

Stack from ghstack (oldest at bottom):

Previously, we made backwards graph compilation lazy to avoid paying
for compilation if the user didn't actually end up using the backwards
graph. This was useful in the old days when a lot of things in Inductor
didn't work and we could bypass errors this way.

However, this has a bad implication for dynamic shapes: the backwards
graph compilation can trigger extra guards, which are too late to
install in the Dynamo context if we wait until backwards is being run.
So in this PR I move us back to compiling backwards graph immediately
if we capture any SymInts for backwards.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov
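
To make the timing problem in the description concrete, here is a toy model in plain Python. It is not PyTorch code: `GuardSet`, `compile_backward`, `compile_frame`, and the guard strings are hypothetical stand-ins. The sketch assumes only what the description states, namely that compiling the backward graph can produce extra guards and that guards can no longer be installed once forward compilation has finished.

```python
class GuardSet:
    """Toy stand-in for the guards attached to one compiled frame."""

    def __init__(self):
        self.guards = []
        self.sealed = False

    def add(self, guard):
        if self.sealed:
            raise RuntimeError(f"too late to install guard: {guard}")
        self.guards.append(guard)

    def seal(self):
        # Once forward compilation finishes, the guard set is fixed.
        self.sealed = True


def compile_backward(guards):
    # Hypothetical: compiling backward discovers one more shape constraint.
    guards.add("s0 % 2 == 0")
    return lambda grad: grad


def compile_frame(eager_backward):
    guards = GuardSet()
    guards.add("s0 > 1")  # guard discovered while compiling forward
    backward = compile_backward(guards) if eager_backward else None
    guards.seal()

    def run_backward(grad):
        nonlocal backward
        if backward is None:
            # Lazy path: the guard set is already sealed at this point.
            backward = compile_backward(guards)
        return backward(grad)

    return guards, run_backward


eager_guards, eager_bw = compile_frame(eager_backward=True)
lazy_guards, lazy_bw = compile_frame(eager_backward=False)

eager_result = eager_bw(1.0)
try:
    lazy_bw(1.0)
    lazy_error = None
except RuntimeError as exc:
    lazy_error = str(exc)
```

In this model the eager frame ends up with both guards installed, while the lazy frame only finds out about the second guard when backward runs, after it can no longer be recorded.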

Previously, we made backwards graph compilation lazy to avoid paying
for compilation if the user didn't actually end up using the backwards
graph.  This was useful in the old days when a lot of things in Inductor
didn't work and we could bypass errors this way.

However, this has a bad implication for dynamic shapes: the backwards
graph compilation can trigger extra guards, which are too late to
install in the Dynamo context if we wait until backwards is being run.
So in this PR I move us back to compiling backwards graph immediately.
This should also make it easier to predict when compilation occurs,
since compilation now all happens up front during forwards.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
@ezyang ezyang requested a review from Chillee as a code owner July 11, 2023 14:26

pytorch-bot bot commented Jul 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104971

Note: Links to docs will display an error until the docs builds have been completed.

✅ 3 Unrelated Failures

As of commit ada9501:

BROKEN TRUNK - The following job failed but was present on the merge base 8c479d3:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed, likely due to flakiness present on trunk, and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: 74ba604
Pull Request resolved: #104971
with track_graph_compiling(aot_config, "backward"):
    placeholder_list = fx_placeholder_vals(bw_module)

    compiled_bw_func = aot_config.bw_compiler(
Collaborator

fwiw - one thing that subclass support will (eventually, not immediately) need is backwards guards: when we generate the joint, we might incorrectly assume that the grad_outputs are/are_not subclasses, which would require us to re-trace and recompile the backward later (I'm writing a doc on subclass requirements, more details will be in the doc).

Compiling the backward eagerly is probably not optimal if we end up having to recompile the backward later, although maybe we're okay with this (since invalidating bw guards is hopefully rare).

Contributor Author

@bdhirsh Your thing is going to need a true two-level cache. But IMO you should just force your users to use compiled backwards in that case, which no longer has this problem. (BTW, @jansel's compiled autograd is what convinced me to do this "simpler" fix; basically if you have any complicated situation where we don't know ahead of time what the gradients will be, you instead use compiled autograd to be able to compile given full info.)

Collaborator

ok, telling users to use compiled autograd when this happens sounds fair! (still need to eventually figure out how to teach compiled autograd how to add extra bw guards)

@ezyang ezyang requested a review from shunting314 July 11, 2023 15:04
@albanD albanD removed their request for review July 11, 2023 18:08
@ezyang ezyang changed the title Immediately compile backwards graph in AOTAutograd Immediately compile backwards graph in AOTAutograd if dynamic shapes Jul 11, 2023
ezyang added a commit that referenced this pull request Jul 12, 2023
ghstack-source-id: 13a23ce
Pull Request resolved: #104971
Contributor Author

ezyang commented Jul 12, 2023

If there are no dynamic shapes, I restore the old behavior of lazy compilation to shut up some TorchScript failures. If eager backwards compilation fails, I suppress it, which gets us through some failures in our test suite where our backwards dynamic codegen doesn't actually work.
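
The choice between the two modes can be sketched as below. This is a simplified illustration, not the AOTAutograd implementation: `SymInt` here is a hypothetical marker class standing in for a symbolic size, and `choose_backward_strategy` is an invented helper name. Error suppression is deliberately left out of this sketch.

```python
class SymInt:
    """Hypothetical marker for a symbolic (dynamic) size."""

    def __init__(self, name):
        self.name = name


def choose_backward_strategy(backward_inputs):
    # Eager compilation is only needed when a symbolic size flows into
    # the backward graph; otherwise keep the old lazy behavior.
    if any(isinstance(x, SymInt) for x in backward_inputs):
        return "eager"
    return "lazy"


static_choice = choose_backward_strategy([3, 4])
dynamic_choice = choose_backward_strategy([SymInt("s0"), 4])
```

With only concrete sizes the sketch keeps lazy compilation, and a single symbolic size among the backward inputs is enough to switch to eager compilation.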

ezyang added the ciflow/trunk and topic: not user facing labels Jul 12, 2023
Comment on lines +2803 to +2805
# NB: It's important to compile backwards ahead of time, as this may
# add extra guards which we need to apply to the Dynamo cache at
# forwards
Collaborator

yea


def run_and_get_triton_code(fn, *args, **kwargs):
    _, source_codes = run_and_get_code(fn, *args, **kwargs)
    # Can have two outputs if backwards was eagerly compiled
Collaborator

Can we store a flag to drive whether this should be exactly 1, or one of (1, 2)?

Contributor Author

It's very awkward, because the current implementation will attempt to compile backwards, and if backwards fails to compile, it suppresses the error and returns anyway. So there isn't really a clear delineation.
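
For reference, the reviewer's suggested flag might look roughly like the sketch below. `pick_triton_code` and its `backward_may_compile` parameter are hypothetical and do not appear in the PR; the sketch also assumes the forward source comes first in the list. As the reply above points out, a caller often cannot know the right value for such a flag, because a failed eager backward compile is silently skipped.

```python
def pick_triton_code(source_codes, backward_may_compile):
    """Return the first generated source after checking how many exist."""
    # One source when only forward compiled, two when backward was also
    # compiled eagerly.
    allowed = (1, 2) if backward_may_compile else (1,)
    if len(source_codes) not in allowed:
        raise AssertionError(
            f"expected {allowed} generated sources, got {len(source_codes)}"
        )
    return source_codes[0]


forward_only = pick_triton_code(["fwd"], backward_may_compile=False)
with_backward = pick_triton_code(["fwd", "bwd"], backward_may_compile=True)
try:
    pick_triton_code(["fwd", "bwd"], backward_may_compile=False)
    strict_rejected = False
except AssertionError:
    strict_rejected = True
```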

)
m.eval()
self.common(m, (torch.randn([16, 32]),), check_lowp=False)
with torch.no_grad():
Collaborator

Why this change?

Contributor Author

Layer norm's backward compilation doesn't work, so the no_grad forces us not to attempt to compile it.

Comment on lines +2813 to +2816
# saved activations can have different stride to eager if
# the compiler does layout optimization. We should restride the
# tensor passed in for compiling the backward graph using the
# saved tensor's stride.
Collaborator

I would stamp but I don't know the nuances of strides well enough.
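
As background for the stride comment quoted above, here is a small pure-Python illustration of why a placeholder may need restriding. It does not use PyTorch: `TensorMeta`, `contiguous_strides`, `channels_last_strides`, and `restride_placeholders` are hypothetical names, and the sketch assumes the layout optimization in question turns a contiguous NCHW activation into a channels-last one.

```python
from collections import namedtuple

# Hypothetical record of a tensor's shape and stride (no real storage).
TensorMeta = namedtuple("TensorMeta", ["shape", "stride"])


def contiguous_strides(shape):
    # Row-major strides: the last dimension is densest.
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def channels_last_strides(shape):
    # NCHW shape stored in NHWC order: channels become the densest dim.
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def restride_placeholders(placeholders, saved):
    # Make each placeholder carry the stride of the tensor actually saved.
    result = []
    for ph, real in zip(placeholders, saved):
        if ph.stride != real.stride:
            ph = ph._replace(stride=real.stride)
        result.append(ph)
    return result


shape = (2, 3, 4, 5)
placeholders = [TensorMeta(shape, contiguous_strides(shape))]
saved = [TensorMeta(shape, channels_last_strides(shape))]
fixed = restride_placeholders(placeholders, saved)
```

The shape is unchanged by restriding; only the stride metadata is updated so the backward compiler sees the layout it will actually receive.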

ezyang added a commit that referenced this pull request Jul 13, 2023
ghstack-source-id: 47eb2c4
Pull Request resolved: #104971
Contributor Author

ezyang commented Jul 14, 2023

This is ready to go, just waiting for review.

Comment on lines +2858 to +2866
try:
    compiled_bw_func = aot_config.bw_compiler(
        bw_module, placeholder_list
    )
except Exception:
    log.warning(
        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
        exc_info=True
    )
Contributor

What necessitates this? It would be nice to not land with a try-catch.

Contributor Author

Per operator backwards dynamic codegen is still buggy af. Sample run: https://hud.pytorch.org/pytorch/pytorch/pull/104971?sha=b0bb7782a9eb8b85acbcb9ffe1835c07c6dc9f80

I'm not actually suppressing anything. If you actually try to compile backwards, we will try to compile again and THEN fail. If this suppression works out, it just means you didn't actually need the backwards graph at all.
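
The behavior described in this reply can be sketched with a toy class. This is an assumption-laden illustration rather than the PR's code: `BackwardSlot`, `compile_eagerly`, `run`, and `broken_compiler` are hypothetical names. The point is that a failed eager attempt leaves the slot empty, so the failure resurfaces only if backward is actually run.

```python
class BackwardSlot:
    """Holds a backward graph that may be compiled eagerly or lazily."""

    def __init__(self, bw_compiler, bw_module):
        self.bw_compiler = bw_compiler
        self.bw_module = bw_module
        self.compiled = None
        self.attempts = 0

    def _compile(self):
        self.attempts += 1
        return self.bw_compiler(self.bw_module)

    def compile_eagerly(self):
        # Failure here is swallowed: backward may never be needed.
        try:
            self.compiled = self._compile()
        except Exception:
            pass

    def run(self, grad):
        if self.compiled is None:
            # Second attempt; this time any error propagates to the caller.
            self.compiled = self._compile()
        return self.compiled(grad)


def broken_compiler(module):
    raise NotImplementedError("backward codegen failed")


slot = BackwardSlot(broken_compiler, "bw_graph")
slot.compile_eagerly()  # forward-time attempt, error suppressed
attempts_after_forward = slot.attempts

try:
    slot.run(1.0)
    error = None
except NotImplementedError as exc:
    error = str(exc)
```

A caller that never runs backward sees no error at all, which matches the claim that nothing is hidden from users who do need the backward graph.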

Collaborator

Chillee left a comment

iiuc, this is mostly just code movement?

# the compiler does layout optimization. We should restride the
# tensor passed in for compiling the backward graph using the
# saved tensor's stride.
for i in range(len(placeholder_list)):
Collaborator

I'm assuming this is all code movement?

Contributor Author

Yep

Contributor Author

ezyang commented Jul 17, 2023

@pytorchbot merge -r

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gh/ezyang/2220/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/104971)

pytorchmergebot pushed a commit that referenced this pull request Jul 17, 2023
ghstack-source-id: 6101c9f
Pull Request resolved: #104971
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here
