[Prototype] Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time #89392
voznesenskym wants to merge 27 commits into master
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89392
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 5321f59. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    if not isinstance(x, torch.Tensor):
        return x
    if isinstance(x, torch._subclasses.fake_tensor.FakeTensor):
        return x
This looks questionable. If an argument is already a fake tensor, it's unlikely to be consistent with the freshly allocated fake mode. Which means you'd probably get an error if you tried to actually use it. Better to not support this case.
Fwiw this is copied from the old functionality, but let's see if we can make things better than we found them.
This isn't copied. The old code doesn't explicitly test for FakeTensor.
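The stricter behavior suggested in this review thread can be sketched in plain Python. The stand-in `FakeTensor`, `FakeTensorMode`, and `RealTensor` classes below are illustrative simplifications, not the real `torch._subclasses` types: the point is that an already-fake input likely belongs to a different mode than the freshly allocated one, so failing loudly beats passing it through.

```python
class FakeTensorMode:
    """Stand-in for torch._subclasses.fake_tensor.FakeTensorMode (illustrative)."""
    def from_real(self, t):
        return FakeTensor(mode=self, shape=t.shape)

class FakeTensor:
    """Stand-in fake tensor: carries the mode that created it."""
    def __init__(self, mode, shape=None):
        self.mode = mode
        self.shape = shape

class RealTensor:
    """Stand-in for a real torch.Tensor."""
    def __init__(self, shape):
        self.shape = shape

def fakify(x, fake_mode):
    # Non-tensor values pass through unchanged.
    if not isinstance(x, (RealTensor, FakeTensor)):
        return x
    if isinstance(x, FakeTensor):
        # An already-fake input is probably tied to a *different* mode;
        # using it under fake_mode would fail later, so fail loudly now.
        raise AssertionError("fakify() expects real tensors only")
    return fake_mode.from_real(x)
```

This is the "better to not support this case" option: rejecting fake inputs up front instead of silently returning them.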
functorch/_src/aot_autograd.py
Outdated
    def fakify_params_and_buffers(flat_args):
        nonlocal fake_mode
        if config.use_fake_tensor:
            flat_inputs, _ = pytree.tree_flatten(inputs)
A comment here saying what this is doing would help other readers.
functorch/_src/aot_autograd.py
Outdated
    def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
    def aot_module_simplified(
        mod: nn.Module,
        inputs,
A comment saying what the acceptable inputs are would be good here.
…sors" Strategy taken from voz's #89392, but my implementation is a bit different. If a fake tensor is provided, we use its FakeTensorMode (and more importantly, its ShapeEnv--this is what is tested in the new unit test). Only one tensor needs to be fake; if nothing is fake we just make a fresh mode as before. Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
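The "reuse an existing mode" strategy described in this commit message can be sketched as follows. The classes and the `detect_fake_mode` helper are hypothetical plain-Python stand-ins (not the torch API): scan the flattened inputs for a fake tensor and, if one is found, adopt its mode (and hence its ShapeEnv); otherwise allocate a fresh mode, as before.

```python
class ShapeEnv:
    """Stand-in for the symbolic-shapes environment."""

class FakeTensorMode:
    """Stand-in mode: owns a ShapeEnv."""
    def __init__(self):
        self.shape_env = ShapeEnv()

class FakeTensor:
    """Stand-in fake tensor: remembers its creating mode."""
    def __init__(self, mode):
        self.fake_mode = mode

def detect_fake_mode(flat_inputs):
    # Only one input needs to be fake for its mode (and ShapeEnv) to win.
    for x in flat_inputs:
        if isinstance(x, FakeTensor):
            return x.fake_mode
    # Nothing fake: allocate a fresh mode, matching the old behavior.
    return FakeTensorMode()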
…ograd, move aot_autograd compilation to lowering time" After all of the preparatory commits, this is a subset of the changes in #89392 that actually change us to propagating fake tensors to backends. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
Strategy taken from voz's #89392 but my implementation strategy is a bit different. If a fake tensor is provided, we use its FakeTensorMode (and more importantly, its ShapeEnv--this is what is tested in the new unit test). Only one tensor needs to be fake; if nothing is fake we just make a fresh mode as before. Signed-off-by: Edward Z. Yang <ezyangfb.com> [ghstack-poisoned]
Taken from voz's #89392 Signed-off-by: Edward Z. Yang <ezyangfb.com> cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
…arguments directly" This is extracted from voz's #89392 Previously, the implementation did some half-assed caching where it returned a callable, that when invoked for the first time, actually performed the compilation. Delaying the compilation like this... seems totally unnecessary? To make matters worse, this has cost (we have to check if we hit the cache) and unsound (because the compiled function may not be valid for other arguments.) So instead, we ask user to provide arguments, and compile everything immediately. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
This is extracted from voz's #89392 Previously, the implementation did some half-assed caching where it returned a callable, that when invoked for the first time, actually performed the compilation. Delaying the compilation like this... seems totally unnecessary? To make matters worse, this has cost (we have to check if we hit the cache) and unsound (because the compiled function may not be valid for other arguments.) So instead, we ask user to provide arguments, and compile everything immediately. Signed-off-by: Edward Z. Yang <ezyangfb.com> ghstack-source-id: 49f39c5 Pull Request resolved: #89669
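The caching problem described in this commit message can be illustrated with a toy sketch. `compile_fn`, `old_delayed`, and `new_eager` below are hypothetical names, not the aot_autograd API: the old scheme returned a callable that compiled on first invocation and cached the result, paying a cache check per call and silently reusing a compilation that may not be valid for different arguments; the new scheme asks the caller for example arguments and compiles immediately.

```python
def compile_fn(fn, example_args):
    # Stand-in "compiler": specializes on the number of arguments seen.
    arity = len(example_args)
    def compiled(*args):
        assert len(args) == arity, "compiled function invalid for these arguments"
        return fn(*args)
    return compiled

def old_delayed(fn):
    # Old scheme: compile lazily on first call, then cache.
    cache = {}
    def wrapper(*args):
        if "compiled" not in cache:          # per-call cache check (cost)
            cache["compiled"] = compile_fn(fn, args)
        return cache["compiled"](*args)      # may be stale for new args (unsound)
    return wrapper

def new_eager(fn, example_args):
    # New scheme: caller supplies example arguments; compile right away.
    return compile_fn(fn, example_args)
```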
There is only one call site for compiler_fn, so we can safely delay wrapping verify correctness to here. This will help later when we change the backend compiler calling convention to pass fake tensors (but I need to pass real tensors here). This is adapted from voz's changes at #89392 but with fewer changes to the substantive logic. I only moved the relevant inner implementation; there are no changes otherwise. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: #89662 Approved by: https://github.com/voznesenskym
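Wrapping verify-correctness at the single call site can be sketched as below. `wrap_verify_correctness` and its signature are illustrative, not the dynamo API: the wrapper runs the compiled function and the original on the same example inputs and compares results before handing the compiled function back. As the commit message notes, this check needs real tensors, since fake tensors carry no data to compare.

```python
def wrap_verify_correctness(compiler_fn):
    """Wrap a backend compiler with a result-equality check (illustrative)."""
    def wrapped(gm, example_inputs):
        compiled = compiler_fn(gm, example_inputs)
        # Run both on the same *real* example inputs and compare.
        if compiled(*example_inputs) != gm(*example_inputs):
            raise RuntimeError("backend produced different output than eager")
        return compiled
    return wrapped
```

Because there is only one call site, applying the wrapper there (rather than baking it into every backend) keeps the check out of the backend calling convention, which is what makes the later switch to fake-tensor inputs possible.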
Taken from voz's #89392 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: #89656 Approved by: https://github.com/voznesenskym
cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire