
Make aot_module_simplified accept fake tensors#89670

Closed
ezyang wants to merge 9 commits intogh/ezyang/1589/basefrom
gh/ezyang/1589/head

Conversation

Contributor

@ezyang ezyang commented Nov 25, 2022

Stack from ghstack (oldest at bottom):

The strategy is taken from voz's #89392, but my implementation
is a bit different.

If a fake tensor is provided, we use its FakeTensorMode
(and more importantly, its ShapeEnv--this is what is tested
in the new unit test). Only one tensor needs to be fake;
if nothing is fake we just make a fresh mode as before.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
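The selection logic described above can be sketched roughly like this (`FakeTensor` and `FakeTensorMode` here are minimal hypothetical stand-ins, not the real `torch._subclasses` types):

```python
# Stand-in doubles for illustration only; the real types live in
# torch._subclasses and carry much more state.

class FakeTensorMode:
    def __init__(self):
        self.shape_env = object()  # stands in for the real ShapeEnv

class FakeTensor:
    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

def select_fake_mode(flat_args):
    # Reuse the mode (and therefore the ShapeEnv) of the first fake input;
    # fall back to a fresh mode when no input is fake.
    for x in flat_args:
        if isinstance(x, FakeTensor):
            return x.fake_mode
    return FakeTensorMode()

mode = FakeTensorMode()
args = [1.0, FakeTensor(mode), FakeTensor(mode)]
assert select_fake_mode(args) is mode       # reuses the existing mode
assert select_fake_mode([1.0]) is not mode  # fresh mode otherwise
```

The point of reusing the mode is that its ShapeEnv comes along with it, which is what the new unit test exercises.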

@pytorch-bot

pytorch-bot bot commented Nov 25, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89670

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit 1cc5577:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ezyang ezyang added the ciflow/trunk and ciflow/inductor labels Nov 25, 2022
Comment on lines +1494 to +1497
for x in flat_args:
    if isinstance(x, FakeTensor):
        fake_mode = x.fake_mode
        break
Collaborator

I have a util I wrote for this, in the parent PR

def fake_mode_from_tensors(inputs: List[Any]):
    """
    Takes a list of anything, unflattened is fine, returns a fake_mode
    if any are fake. All fake modes on all fake tensors must be identical.
    Returns None if no fake_mode is found.
    """
    flat_inputs, _ = tree_flatten(inputs)
    fake_mode = None
    for flat_input in flat_inputs:
        if isinstance(flat_input, torch._subclasses.FakeTensor):
            if fake_mode is None:
                fake_mode = flat_input.fake_mode
            else:
                assert fake_mode == flat_input.fake_mode
    return fake_mode

Feel free to replace the continue-and-check-all-same-fake-mode assert w/ a break as you have it, but I do think this should be in utils
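The difference between the two variants can be shown with minimal stand-ins (a plain list instead of a pytree, and a hypothetical `FakeTensor` double rather than the real torch type):

```python
# Stand-in double for illustration; the real FakeTensor lives in
# torch._subclasses and the real utility flattens a pytree first.

class FakeTensor:
    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

def fake_mode_from_tensors(inputs):
    # Assert-all-identical variant: scan every input and insist that all
    # fake tensors share a single mode; return None when nothing is fake.
    fake_mode = None
    for x in inputs:
        if isinstance(x, FakeTensor):
            if fake_mode is None:
                fake_mode = x.fake_mode
            else:
                assert fake_mode is x.fake_mode, "mixed fake modes"
    return fake_mode

mode = object()
assert fake_mode_from_tensors([1, FakeTensor(mode), FakeTensor(mode)]) is mode
assert fake_mode_from_tensors([1, 2]) is None
```

The break-early version in the diff skips the all-identical check, which is fine when the caller guarantees a single mode.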

Contributor Author

I didn't use this utility because flat_args is guaranteed to be a flat list of tensors here, whereas the utility does a flatten first. Better not to use pytree if you don't need it.

Comment on lines +1508 to 1517
if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode):
    def convert(idx, x):
        if not isinstance(x, torch.Tensor):
            return x
        if isinstance(x, FakeTensor):
            assert x.fake_mode is fake_mode
            return x
        if idx < aot_config.num_params_buffers and config.static_weight_shapes:
            return fake_mode.from_tensor(x, static_shapes=True)
        return fake_mode.from_tensor(x, static_shapes=False)
Collaborator

I've gone back and forth on the signal here.

  1. Dynamo sets config.use_fake_tensor correctly
  2. We check config.use_fake_tensor in other places
  3. You should never be in a state where config.use_fake_tensor == False and you are getting a fake mode

So maybe instead we either:

A) Assert that we are in fake config mode if we see a fake tensor in inputs on L1495 and use the presence of fake_mode to be analogous to config.use_fake_tensor (since you can only get one via inputs in config.use_fake_tensor or by making your own on L1500, also within config)

B) Or drop the isinstance(fake_mode) checks and assume that if we have the config set, we have the fake mode, and guard that with an assert.

Contributor Author

I don't think this is worth litigating much, because I think we should delete this config option (e.g., #89663 ).
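Setting the config debate aside, the conversion policy in the snippet under discussion can be sketched with stand-ins (the `fakeify` callback below is a hypothetical stub, not the real `fake_mode.from_tensor`):

```python
# Illustrates the static-vs-dynamic shape decision: the first
# num_params_buffers arguments (parameters/buffers) are fakeified with
# static shapes, the remaining inputs with dynamic shapes.

def convert_args(flat_args, num_params_buffers, fakeify):
    out = []
    for idx, x in enumerate(flat_args):
        static = idx < num_params_buffers
        out.append(fakeify(x, static_shapes=static))
    return out

# Stub that records the static_shapes decision made for each argument.
calls = []
def fakeify(x, static_shapes):
    calls.append(static_shapes)
    return x

convert_args(["param", "buffer", "x", "y"], num_params_buffers=2, fakeify=fakeify)
assert calls == [True, True, False, False]
```

Already-fake inputs and non-tensors pass through unchanged in the real code; this sketch only shows the shape-policy split.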

@voznesenskym voznesenskym mentioned this pull request Nov 26, 2022
Collaborator

@Chillee Chillee left a comment

Is this just a temporary PR that will no longer be needed once we do fakification at the Dynamo level? (or at least prior to create_aot_dispatcher_function?).

I thought the contract with create_aot_dispatcher_function was going to be that it was passed a list of fake tensors.

@ezyang
Contributor Author

ezyang commented Nov 28, 2022

@Chillee argument fakeification can be deleted once we pass in fake tensors from dynamo. But we are always on the hook for fakeifying parameters.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#89670
Approved by: https://github.com/voznesenskym
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1589/head branch June 8, 2023 16:35

Labels

ciflow/inductor, ciflow/trunk, release notes: torch.func, topic: improvements
