Make aot_module_simplified accept fake tensors #89670
ezyang wants to merge 9 commits into gh/ezyang/1589/base
Conversation
Strategy taken from voz's #89392, but my implementation strategy is a bit different. If a fake tensor is provided, we use its FakeTensorMode (and, more importantly, its ShapeEnv--this is what is tested in the new unit test). Only one tensor needs to be fake; if nothing is fake, we just make a fresh mode as before. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
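The mode-selection behavior described above can be sketched with toy stand-ins. `FakeTensor`, `FakeTensorMode`, and `RealTensor` here are illustrative classes, not the real `torch._subclasses` types; only the selection logic mirrors the PR:

```python
class FakeTensorMode:
    """Toy stand-in; the real mode also carries a ShapeEnv, which is why
    reusing an existing mode matters for symbolic shapes."""
    def __init__(self, shape_env=None):
        self.shape_env = shape_env

class FakeTensor:
    """Toy stand-in: a fake tensor remembers the mode that created it."""
    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

class RealTensor:
    """Toy stand-in for an ordinary (non-fake) tensor."""
    pass

def select_fake_mode(flat_args):
    """Reuse the FakeTensorMode of the first fake input, if any;
    otherwise create a fresh mode, as before this PR."""
    for x in flat_args:
        if isinstance(x, FakeTensor):
            return x.fake_mode
    return FakeTensorMode()

# Only one input needs to be fake for its mode (and ShapeEnv) to be reused.
shared = FakeTensorMode(shape_env="shared-shape-env")
assert select_fake_mode([RealTensor(), FakeTensor(shared)]) is shared

# With no fake inputs, a fresh mode is created.
assert select_fake_mode([RealTensor()]) is not shared
```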
```python
for x in flat_args:
    if isinstance(x, FakeTensor):
        fake_mode = x.fake_mode
        break
```
I have a util I wrote for this, in the parent PR:

```python
from typing import Any, List

import torch
from torch.utils._pytree import tree_flatten

def fake_mode_from_tensors(inputs: List[Any]):
    """
    Takes a list of anything (unflattened is fine) and returns a fake_mode
    if any input is fake. All fake modes on all fake tensors must be
    identical. Returns None if no input is fake.
    """
    flat_inputs, _ = tree_flatten(inputs)
    fake_mode = None
    for flat_input in flat_inputs:
        if isinstance(flat_input, torch._subclasses.FakeTensor):
            if fake_mode is None:
                fake_mode = flat_input.fake_mode
            else:
                assert fake_mode == flat_input.fake_mode
    return fake_mode
```
Feel free to replace the continue-and-check-all-same-fake-mode assert with a break as you have it, but I do think this should live in utils.
I didn't use this utility because flat_args is guaranteed to be a flat list of tensors here, whereas the utility does a flatten first. Better not to use pytree if you don't need it.
```python
if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode):
    def convert(idx, x):
        if not isinstance(x, torch.Tensor):
            return x
        if isinstance(x, FakeTensor):
            assert x.fake_mode is fake_mode
            return x
        if idx < aot_config.num_params_buffers and config.static_weight_shapes:
            return fake_mode.from_tensor(x, static_shapes=True)
        return fake_mode.from_tensor(x, static_shapes=False)
```
I've gone back and forth on the signal here.

- Dynamo sets config.use_fake_tensor correctly.
- We check config.use_fake_tensor in other places.
- You should never be in a state where config.use_fake_tensor == False and you are getting a fake mode.

So maybe instead we either:

A) Assert that we are in fake config mode if we see a fake tensor in the inputs on L1495, and treat the presence of fake_mode as analogous to config.use_fake_tensor (since you can only get one via inputs under config.use_fake_tensor, or by making your own on L1500, also within config).

B) Drop the isinstance(fake_mode, FakeTensorMode) checks and assume that if the config is set, we have the fake mode, guarding that with an assert.
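Option (B) could be sketched roughly as follows. `check_fake_invariant` is a hypothetical helper name introduced for illustration; the real code would consult `config.use_fake_tensor` directly rather than take it as a parameter:

```python
def check_fake_invariant(use_fake_tensor, fake_mode):
    """Hypothetical sketch of option (B): the config flag is the single
    signal, and the relationship to fake_mode is guarded by asserts."""
    if use_fake_tensor:
        # If the config flag is set, a fake mode must be present.
        assert fake_mode is not None, (
            "config.use_fake_tensor is set but no FakeTensorMode available"
        )
    else:
        # We should never be handed a fake mode with the flag off.
        assert fake_mode is None, (
            "got a fake mode while config.use_fake_tensor == False"
        )
    return fake_mode
```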
I don't think this is worth litigating much, because I think we should delete this config option (e.g., #89663 ).
Chillee left a comment:
Is this just a temporary PR that will no longer be needed once we do fakification at the Dynamo level? (or at least prior to create_aot_dispatcher_function?).
I thought the contract with create_aot_dispatcher_function was going to be that it was passed a list of fake tensors.
@Chillee argument fakeification can be deleted once we pass in fake tensors from dynamo. But we are always on the hook for fakeifying parameters.
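The remaining parameter-fakeification responsibility can be illustrated with a toy stand-in. `FakeTensorMode.from_tensor` here just tags its input; the real API is `torch._subclasses.FakeTensorMode.from_tensor`, with parameters and buffers getting `static_shapes=True` as in the `convert` hunk above:

```python
class FakeTensorMode:
    """Toy stand-in for torch._subclasses.FakeTensorMode; from_tensor
    just tags its input so the dispatch is visible."""
    def from_tensor(self, t, static_shapes=False):
        return ("fake", t, static_shapes)

def fakeify_inputs(flat_args, num_params_buffers, fake_mode):
    """Parameters and buffers come first in flat_args and keep static
    shapes; the remaining user arguments may have dynamic shapes."""
    return [
        fake_mode.from_tensor(x, static_shapes=(idx < num_params_buffers))
        for idx, x in enumerate(flat_args)
    ]

mode = FakeTensorMode()
faked = fakeify_inputs(["w", "b", "x"], num_params_buffers=2, fake_mode=mode)
# The two params/buffers are static; the user input "x" is not.
assert faked == [("fake", "w", True), ("fake", "b", True), ("fake", "x", False)]
```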
Pull Request resolved: pytorch#89670
Approved by: https://github.com/voznesenskym