
Fix tensor subclass + dynamic shapes in torch.compile + aot autograd#125941

Closed
guilhermeleobas wants to merge 75 commits into gh/guilhermeleobas/48/base from gh/guilhermeleobas/48/head

Conversation


pytorch-bot bot commented May 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125941

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c9db0c5 with merge base 3b0f393 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

guilhermeleobas added a commit that referenced this pull request May 10, 2024

guilhermeleobas commented May 10, 2024

This is pretty much a work in progress; I just want to see what is currently breaking.

Fixes issue: #124619

Changes

This PR addresses a bug in the interaction between tensor subclasses and symbolic shapes. For each subclass, it appends the subclass's sizes to the flat list of arguments and recomputes the shapes from them at runtime.

Most of the changes are in the unwrap_tensor_subclasses function, which now takes two extra flags: append_extra and is_runtime. While tracing, if append_extra is true and we are tracing the forward graph, the extra size arguments are appended.

An extra field (flat_tensor_extra_sizes_offset) is introduced in SubclassCreationMeta. It stores the offset, counted from right to left, of the sizes associated with a tensor subclass. At runtime, the sizes can be recovered as args[len(args) - offset : len(args) - offset + num_sizes], where offset is the new field and num_sizes is the number of sizes for the given subclass.
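For illustration, the right-to-left offset lookup described above can be sketched with plain Python lists (the function and variable names here are made up for the example, not taken from the PR):

```python
# Sketch of the right-to-left offset lookup described above.
# Names are illustrative only, not the actual PR code.

def sizes_from_args(args, offset, num_sizes):
    # `offset` counts positions from the right end of the flat
    # argument list; the subclass sizes occupy `num_sizes` slots
    # starting there.
    start = len(args) - offset
    return args[start : start + num_sizes]

# Flat runtime arguments: two inner tensors (strings stand in for
# them here), followed by the appended sizes for two subclasses.
args = ["t_a", "t_b", 3, 4, 5, 6]

print(sizes_from_args(args, offset=4, num_sizes=2))  # [3, 4]
print(sizes_from_args(args, offset=2, num_sizes=2))  # [5, 6]
```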

Test plan

Add tests for two different subclasses: TwoTensor and DoubleTensor. The
latter is a wrapper that behaves as if the inner tensor were twice its
original size.

The set of tests is composed of functions that return a mix of subclasses
and plain tensors.

[ghstack-poisoned]
[ghstack-poisoned]
guilhermeleobas added a commit that referenced this pull request May 16, 2024
@guilhermeleobas guilhermeleobas requested a review from bdhirsh May 16, 2024 16:30
@guilhermeleobas guilhermeleobas marked this pull request as ready for review May 16, 2024 16:30

ezyang commented May 17, 2024

What exactly is the algorithmic strategy here?

guilhermeleobas added a commit that referenced this pull request May 23, 2024
@guilhermeleobas guilhermeleobas marked this pull request as draft May 23, 2024 13:30
@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label May 28, 2024
guilhermeleobas added a commit that referenced this pull request May 29, 2024
guilhermeleobas added a commit that referenced this pull request May 30, 2024

bdhirsh commented Oct 24, 2024

@guilhermeleobas the new recursive size/stride handling looks mostly good to me. I was still a bit worried about edge cases, so I tried a few tests locally and I got a failure involving nested subclasses:

(I wanted to stress test a bit so I used a version of TwoTensor that relaxes the constraint that the shapes of the two inner tensors are the same)

import torch
import torch.utils._pytree as pytree


# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        assert (
            a.device == b.device
            and a.layout == b.layout
            and a.requires_grad == b.requires_grad
            and a.dtype == b.dtype
        )
        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

        return out

    def __init__(self, a, b, outer_size=None, outer_stride=None):
        self.a = a
        self.b = b

    def __repr__(self):
        a_repr = repr(self.a)
        b_repr = repr(self.b)
        return f"TwoTensor({a_repr}, {b_repr})"

    def __tensor_flatten__(self):
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a, b = inner_tensors["a"], inner_tensors["b"]
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return TwoTensor(a, b, outer_size, outer_stride)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)

        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)

        out_a = func(*args_a, **kwargs_a)
        out_b = func(*args_b, **kwargs_b)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_b_flat = pytree.tree_leaves(out_b)
        # for aten ops that return non-tensors, just assume that
        # our two inner tensors return the same value
        out_flat = [
            TwoTensor(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
            for o_a, o_b in zip(out_a_flat, out_b_flat)
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        return out

@torch.compile(dynamic=True)
def f(x, y):
    tmp1 = x.sin()
    tmp2 = y.sin()
    return tmp1.sum(), tmp2.sum()


x = TwoTensor(
    TwoTensor(
        torch.randn(3, 4),
        torch.randn(5, 6, 7),
    ),
    TwoTensor(
        torch.randn(4),
        torch.randn(2, 3),
    )
)

y = TwoTensor(
    torch.randn(2, 3, 4, 5),
    TwoTensor(
        torch.randn(3, 4),
        torch.randn(5),
    )
)

out = f(x, y)

This fails for me, with:

  File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in inner_fn
    unwrapped_args = runtime_unwrap_tensor_subclasses(
  File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/subclass_utils.py", line 239, in runtime_unwrap_tensor_subclasses
    xs_inner.extend(flatten_subclass(typing.cast(Tensor, x), meta))
  File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/subclass_utils.py", line 206, in flatten_subclass
    tensors_and_sizes.extend(flatten_subclass(inner_tensor, inner_meta))
  File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/subclass_utils.py", line 221, in flatten_subclass
    assert len(stride) == len(symint_placeholders)
AssertionError

@guilhermeleobas
Collaborator Author

This fails for me, with:
File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in inner_fn
unwrapped_args = runtime_unwrap_tensor_subclasses(
File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/subclass_utils.py", line 239, in runtime_unwrap_tensor_subclasses
xs_inner.extend(flatten_subclass(typing.cast(Tensor, x), meta))
File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/subclass_utils.py", line 206, in flatten_subclass
tensors_and_sizes.extend(flatten_subclass(inner_tensor, inner_meta))
File "/home/hirsheybar/local/c/pytorch/torch/_functorch/_aot_autograd/subclass_utils.py", line 221, in flatten_subclass
assert len(stride) == len(symint_placeholders)
AssertionError

Oops. It was just one small mistake where I mistyped the name of a variable. Should be good now.

@guilhermeleobas
Collaborator Author

I just noticed that @IvanKobzarev's PR (#138498) does a micro-optimization on unwrap_tensor_subclasses, and since this PR changes that function a lot, I can adapt the code to include his changes. What do you think?


bdhirsh commented Oct 25, 2024

ah yes that would be great. @IvanKobzarev has been looking into subclass runtime overhead, and it would be nice if we can avoid this PR making it too much worse

@guilhermeleobas
Collaborator Author

ah yes that would be great. @IvanKobzarev has been looking into subclass runtime overhead, and it would be nice if we can avoid this PR making it too much worse

@IvanKobzarev, did you use any code to benchmark #138498? If so, can you share it with me?


 if subclass_metas is None:
-    xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
+    get_plain_tensors(typing.cast(Tensor, x), out_append_list=xs_inner)

-    subclass: Tensor, out_append_list: Optional[List[Tensor]] = None
-) -> List[Tensor]:
+    subclass: Tensor, out_append_list: Optional[List[Union[Tensor, int, SymInt]]] = None
+) -> List[Union[Tensor, int, SymInt]]:

hmm, the type signature here is a bit confusing, since we never actually append ints/SymInts to the list in this function. I guess you needed this because in the out_append_list= case, the list we pass in might have symints in it already?

Instead, what do you think of: just refactoring this function to always accept an output list to append to, and mandating that anybody using this API must pass in their own list (from a quick grep there are only 2 call sites of this function, both within AOTAutograd)
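To illustrate the suggested refactor, an always-caller-provided output list could look roughly like this (a hedged sketch with stand-in types, not the actual AOTAutograd code; FakeSubclass is invented for the example):

```python
from typing import Any, List

# Stand-in for a wrapper tensor subclass; the real code would walk
# __tensor_flatten__ on torch.Tensor subclasses instead.
class FakeSubclass:
    def __init__(self, *children: Any) -> None:
        self.children = children

def get_plain_tensors(subclass: FakeSubclass, out: List[Any]) -> None:
    # Callers must always pass their own accumulator list, so the
    # function never allocates one, and the accumulator may already
    # contain non-tensor entries (e.g. ints/SymInts).
    for inner in subclass.children:
        if isinstance(inner, FakeSubclass):
            get_plain_tensors(inner, out)  # recurse into nested subclasses
        else:
            out.append(inner)

acc: List[Any] = [99]  # pretend a size was already appended
get_plain_tensors(FakeSubclass("a", FakeSubclass("b", "c")), acc)
print(acc)  # [99, 'a', 'b', 'c']
```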

@bdhirsh left a comment

left a few more comments, but otherwise I think this is ready to land. Thanks for all the hard work!

@IvanKobzarev
Contributor

ah yes that would be great. @IvanKobzarev has been looking into subclass runtime overhead, and it would be nice if we can avoid this PR making it too much worse

@IvanKobzarev, did you use any code to benchmark #138498? If so, can you share it with me?

Hi, sorry for the delay in replying.

At the moment I use two things for profiling:
1. The not-yet-landed PR #136478, which uses James's profiling; in the test I can then take the times from the logger.
2. Manually averaging time.time_ns() in a global variable inside unwrap_tensor_subclasses() in runtime_wrappers.py.

@guilhermeleobas
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.


mlazos commented Nov 5, 2024

Hi @mlazos, it is. But there's one test that fails if I remove the maybe_enable_thunkify call. I'll sync with @bdhirsh tomorrow.

@guilhermeleobas can that call be removed now? I think it's still there with a note that it can be removed after this PR is closed.


bdhirsh commented Nov 5, 2024

oh @mlazos our current hypothesis is that this context manager was only needed because there were some tests that did tensor * nested_int compute in a compiled region, which @jbschlosser has since banned as part of #138496 (independently of this PR). So I think it's worth a try to kill that code and see if CI is green


mlazos commented Nov 5, 2024

Awesome I will try that


Labels

ciflow/inductor, ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: cpu (CPU specific problem, e.g. perf, algorithm), module: dynamo, module: inductor, oncall: distributed (Add this issue/PR to distributed oncall triage queue), open source, release notes: fx (release notes category)


Development

Successfully merging this pull request may close these issues.

torch.compile + dynamic shapes + tensor subclass graph output is broken

9 participants