
[Prototype] Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time#89392

Closed
voznesenskym wants to merge 27 commits into master from voz/plumb_attempt_again

Conversation

Collaborator

@voznesenskym voznesenskym commented Nov 20, 2022


pytorch-bot bot commented Nov 20, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89392

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5321f59:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

if not isinstance(x, torch.Tensor):
    return x
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor):
    return x
Contributor

This looks questionable. If an argument is already a fake tensor, it's unlikely to be consistent with the freshly allocated fake mode. Which means you'd probably get an error if you tried to actually use it. Better to not support this case.
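The concern above can be illustrated with a toy sketch. This is plain Python, not PyTorch's actual `FakeTensorMode`; the `FakeMode`/`Fake` classes here are hypothetical stand-ins that just track which mode created a value, to show why a fake tensor from an older mode should be rejected rather than silently reused under a freshly allocated mode.

```python
class Fake:
    """Hypothetical fake value that remembers the mode that created it."""
    def __init__(self, value, owner):
        self.value = value
        self.owner = owner

class FakeMode:
    """Hypothetical stand-in for a fake-tensor mode that owns its values."""
    def __init__(self, name):
        self.name = name

    def fakify(self, value):
        return Fake(value, owner=self)

    def use(self, fake):
        # A fake created under a different mode is inconsistent with this
        # mode's state, so refuse it instead of producing wrong results later.
        if fake.owner is not self:
            raise RuntimeError(
                f"fake value belongs to mode {fake.owner.name!r}, "
                f"not {self.name!r}"
            )
        return fake.value

old_mode = FakeMode("old")
fresh_mode = FakeMode("fresh")
x = old_mode.fakify(42)

try:
    fresh_mode.use(x)  # fake from another mode: rejected
except RuntimeError as e:
    print(e)  # fake value belongs to mode 'old', not 'fresh'
```

The real situation is analogous: a pre-existing `FakeTensor` carries state tied to the mode that allocated it, so mixing it into a freshly created mode would likely error on first use.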

Collaborator Author

Fwiw this is copied from the old functionality, but let's see if we can make things better than we found them.

Contributor

This isn't copied. The old code doesn't explicitly test for FakeTensor.

def fakify_params_and_buffers(flat_args):
    nonlocal fake_mode
    if config.use_fake_tensor:
        flat_inputs, _ = pytree.tree_flatten(inputs)
Contributor

A comment here saying what this is doing will help other readers
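For readers unfamiliar with the `pytree.tree_flatten` call in the snippet above: it turns an arbitrarily nested container of tensors into a flat list of leaves plus a spec for rebuilding the structure, so each leaf can be fakified individually. Below is a minimal self-contained sketch of the idea in plain Python (the real implementation lives in `torch.utils._pytree`; this toy version is only illustrative).

```python
def tree_flatten(obj):
    """Flatten nested lists/tuples/dicts into (leaves, spec)."""
    if isinstance(obj, (list, tuple)):
        leaves, specs = [], []
        for item in obj:
            sub_leaves, sub_spec = tree_flatten(item)
            leaves.extend(sub_leaves)
            specs.append(sub_spec)
        return leaves, (type(obj).__name__, specs)
    if isinstance(obj, dict):
        leaves, specs = [], []
        for key in sorted(obj):  # deterministic key order
            sub_leaves, sub_spec = tree_flatten(obj[key])
            leaves.extend(sub_leaves)
            specs.append((key, sub_spec))
        return leaves, ("dict", specs)
    # Anything else is a leaf (in the real code: a tensor, param, etc.)
    return [obj], "leaf"

leaves, spec = tree_flatten({"w": [1, 2], "b": 3})
print(leaves)  # [3, 1, 2]  (dict keys visited in sorted order)
```

Once flattened, a function like `fakify_params_and_buffers` can map a fakification step over the flat leaves without caring about the container structure.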

def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
def aot_module_simplified(
    mod: nn.Module,
    inputs,
Contributor

A comment saying what the acceptable inputs are here would be good.

ezyang added a commit that referenced this pull request Nov 25, 2022
…sors"

Strategy taken from voz's #89392 but my implementation strategy
is a bit different.

If a fake tensor is provided, we use its FakeTensorMode
(and more importantly, its ShapeEnv--this is what is tested
in the new unit test).  Only one tensor needs to be fake;
if nothing is fake we just make a fresh mode as before.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
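The mode-selection strategy described in the commit message above can be sketched as follows. This is toy code, not PyTorch's real `FakeTensorMode`/`ShapeEnv`: the point is just the lookup rule, namely that only one input needs to be fake for its mode (and shape environment) to be reused, and a fresh mode is allocated otherwise.

```python
class ShapeEnv:
    """Toy stand-in for the symbolic-shape environment."""
    pass

class FakeTensorMode:
    """Toy stand-in: every mode owns a shape environment."""
    def __init__(self):
        self.shape_env = ShapeEnv()

class FakeValue:
    """Toy stand-in for a fake tensor, which carries its mode."""
    def __init__(self, mode):
        self.fake_mode = mode

def select_fake_mode(inputs):
    """Reuse the mode of the first fake input; else make a fresh mode."""
    for x in inputs:
        if isinstance(x, FakeValue):
            return x.fake_mode
    return FakeTensorMode()

existing = FakeTensorMode()
fake_in = FakeValue(existing)
assert select_fake_mode([1, fake_in, 2]) is existing  # mode (and ShapeEnv) reused
assert select_fake_mode([1, 2]) is not existing       # fresh mode, as before
```

Reusing the provided mode is what keeps the `ShapeEnv` consistent between tracing and compilation, which is what the new unit test mentioned above exercises.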
ezyang added a commit that referenced this pull request Nov 25, 2022
…ograd, move aot_autograd compilation to lowering time"

After all of the preparatory commits, this is a subset of the
changes in #89392 that actually
change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Nov 25, 2022
Taken from voz's #89392

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Nov 25, 2022
…arguments directly"

This is extracted from voz's #89392

Previously, the implementation did some half-assed caching where it
returned a callable that, when invoked for the first time, actually
performed the compilation.  Delaying the compilation like this seems
totally unnecessary.  To make matters worse, it has a cost (we have to
check whether we hit the cache) and is unsound (the compiled function
may not be valid for other arguments).

So instead, we ask the user to provide arguments, and compile
everything immediately.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
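The lazy-vs-eager compilation tradeoff described in the commit message above can be sketched in a few lines. This is toy code (the function names and the "compiler" are illustrative, not AOTAutograd's real API): the old scheme defers compilation to the first call and pays a cache check on every call, while the new scheme compiles once, up front, against example arguments.

```python
def compile_for(fn, args):
    # Toy "compiler": specialize on the number of arguments seen.
    arity = len(args)
    def compiled(*call_args):
        assert len(call_args) == arity, "compiled for a different signature"
        return fn(*call_args)
    return compiled

def lazy_compile(fn):
    """Old scheme: return a wrapper that compiles on first call and caches."""
    cache = {}
    def wrapper(*args):
        if "compiled" not in cache:          # per-call cache check: overhead
            cache["compiled"] = compile_for(fn, args)
        # Unsound: the cached result is reused even if later args differ
        # from the ones the function was compiled for.
        return cache["compiled"](*args)
    return wrapper

def eager_compile(fn, example_args):
    """New scheme: compile immediately for the provided example arguments."""
    return compile_for(fn, example_args)

add = lambda a, b: a + b
g = lazy_compile(add)
print(g(1, 2))  # 3  (compilation happens here, on first call)

f = eager_compile(add, (0, 0))  # compilation cost paid up front, once
print(f(3, 4))  # 7
```

With eager compilation there is nothing to cache and no first-call special case, and any mismatch between the example arguments and real arguments surfaces at compile time rather than silently reusing a stale compile.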
ezyang added a commit that referenced this pull request Nov 25, 2022
ghstack-source-id: 49f39c5
Pull Request resolved: #89669
ezyang added a commit that referenced this pull request Nov 25, 2022
ghstack-source-id: 8684d1d
Pull Request resolved: #89672
pytorchmergebot pushed a commit that referenced this pull request Nov 25, 2022
There is only one call site for compiler_fn, so we can safely delay
wrapping verify correctness to here.  This will help later when we
change the backend compiler calling convention to pass fake tensors
(but I need to pass real tensors here.)

This is adapted from voz's changes at #89392
but with less changes to the substantive logic.  I only moved the relevant
inner implementation; there are no changes otherwise.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: #89662
Approved by: https://github.com/voznesenskym
pytorchmergebot pushed a commit that referenced this pull request Nov 26, 2022
Pull Request resolved: #89656
Approved by: https://github.com/voznesenskym