
Make Cache a subclass of torch.Tensor #35792

Closed
IlyasMoutawwakil wants to merge 19 commits into main from tensor-cache

Conversation

@IlyasMoutawwakil
Member

What does this PR do?

Both torch script tracing and torch dynamo/fx have restrictions on input types (torch script has more) which makes the export fail as one torch module (the model) is passing another (the cache) around as its input. Having Cache be a subclass of torch.Tensor bypasses these issues and imo makes more sense as the Cache class has no forward and is just a container of torch tensors.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Contributor

@gante left a comment


In principle LGTM. I'm calling in the torch.export <> transformers expert to double-check that these changes are also okay for that goal 🤗

Question: a Cache object holds a list of tensors, usually a pair of tensors per layer. In some cases, different tensors of a cache can live on different devices. Would this conflict with the new inheritance?

Double-checks:

  1. Have you confirmed that slow llama tests and slow cache tests have no regressions with respect to main? (RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv and RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv)
  2. Have you confirmed that llama + static cache + compilation preserves throughput? (can share a script if needed :) )

{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)
# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
Contributor


This is for optimum and you're part of optimum, so I'm assuming it's okay :D

Member Author


Yeah, I'm not sure why this was needed either; tagging @echarlaix @mht-sharma for more info

Contributor


Not sure either

Collaborator


adding @michaelbenayoun who worked on this

Member


It was to be able to record Cache related operations in the fx graph. If another easier solution has been found, I'm all for it.

@gante
Contributor

gante commented Jan 20, 2025

@guangy10 as requested on Slack, have a look if you're available 🙏

@guangy10
Contributor

For correctness testing: we don't have extensive tests, but we do have some correctness guarantees for supported models via test_export_static_cache (pointer). Can you run the slow tests on this PR?

Also, I'm not exactly sure the StaticCache will function as expected. With nn.Module, the cache is registered as a mutable buffer and lifted to a graph input during export. I'm curious how this works with a tensor subclass: it seems tensor subclasses do not directly support buffer registration the way nn.Module does. Can we compare the graphs produced by the nn.Module solution vs. the tensor subclass solution?
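To make the buffer-registration point concrete, here is a toy sketch (hypothetical class and shapes, not the actual transformers StaticCache) of how an nn.Module-based cache exposes its preallocated tensors as buffers, which the exporter can then lift into the graph as mutable state:

```python
import torch
import torch.nn as nn

# Toy sketch, not the real transformers StaticCache: an nn.Module-based
# cache registers its preallocated key/value tensors as buffers.
class ToyStaticCache(nn.Module):
    def __init__(self, num_layers: int, max_seq_len: int, head_dim: int):
        super().__init__()
        for i in range(num_layers):
            self.register_buffer(f"key_cache_{i}", torch.zeros(max_seq_len, head_dim))
            self.register_buffer(f"value_cache_{i}", torch.zeros(max_seq_len, head_dim))

cache = ToyStaticCache(num_layers=2, max_seq_len=8, head_dim=4)
# Buffers show up in named_buffers()/state_dict(), which is what lets
# torch.export treat them as lifted, mutable graph inputs.
print(len(dict(cache.named_buffers())))  # 2 layers x (key + value) = 4
```

A plain torch.Tensor subclass has no named_buffers()/state_dict() machinery of its own, which is the crux of the question above.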

Alternatively, since the motivation is to handle legacy torch script tracing (I assume traffic to this path will decrease over time), would it be a cleaner separation to create a dedicated Cache subclass for it while keeping the one for PyTorch 2.0+ as nn.Module? No need to maintain compatibility with the torch script solution.

@IlyasMoutawwakil
Member Author

IlyasMoutawwakil commented Jan 22, 2025

Question: a Cache object holds a list of tensors, usually a pair of tensors per layer. In some cases, different tensors of a cache can live on different devices. Would this conflict with the new inheritance?

Shouldn't be an issue: we're not using _make_subclass() but rather _make_wrapper_subclass(). The difference is explained by @albanD in https://dev-discuss.pytorch.org/t/whats-the-difference-between-torch-tensor-make-subclass-and-torch-tensor-make-wrapper-subclass/1839:

These two functions do quite different things. The main difference is that when you do _make_subclass(), the current object is an honest-to-goodness Tensor with data in its storage and everything. When you do _make_wrapper_subclass(), the current object has no data and it is expected that some field on the Tensor will be another Tensor (hence the outer one being called wrapper) that contains real data.

One example is the QuantizedTensor subclass, which has two dtypes (a public one, qt.dtype, and an internal one, qt._data.dtype).

Have you confirmed that slow llama tests and slow cache tests have no regressions with respect to main? (RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv and RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv)
Have you confirmed that llama + static cache + compilation preserves throughput? (can share a script if needed :) )

Running them right now (btw, is there a way to trigger them on the CI?). I was only running the llama fast tests and the llama + executorch integration tests.

@IlyasMoutawwakil
Member Author

IlyasMoutawwakil commented Jan 22, 2025

Edit: confirmed these two tests fail on main as well.

Running RUN_SLOW=1 pytest tests/models/llama/test_modeling_llama.py -vv gives two errors, which I guess are related to the machine I'm testing on (A100 vs. the A10 used in the CI):

FAILED tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_llama_3_1_hard - AssertionError: 'Tell[74 chars]ical social and political upheaval in France t[557 chars]s.\n' != 'Tell[74 chars]ical political...
FAILED tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits_bf16 - AssertionError: False is not true

In the first, "social and political" is reversed to "political and social":

E       AssertionError: 'Tell[74 chars]ical social and political upheaval in France t[557 chars]s.\n' != 'Tell[74 chars]ical political and social upheaval in France t[557 chars]s.\n'
E       Diff is 1259 characters long. Set self.maxDiff to None to see it.

In the second, the assertion is not verbose enough:

>       self.assertTrue(
            torch.allclose(
                EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device),
                out.logits.float().mean(-1),
                atol=1e-2,
                rtol=1e-2
            )
        )
E       AssertionError: False is not true

Adding some verbosity:

E       AssertionError: False is not true : Expected: tensor([[-6.5208, -4.1218, -4.9377, -3.2536,  0.8127, -2.9811,  1.2918, -3.3848]],
E              device='cuda:0')
E       Got: tensor([[-6.5081, -4.1175, -4.9761, -3.1678,  0.8199, -3.0029,  1.2809, -3.3309]],
E              device='cuda:0')
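As an aside, torch.testing.assert_close produces this kind of verbose report out of the box; a sketch using the logit means above:

```python
import torch

# The expected/actual mean logits from the failing test above.
expected = torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
got = torch.tensor([[-6.5081, -4.1175, -4.9761, -3.1678, 0.8199, -3.0029, 1.2809, -3.3309]])

try:
    # Unlike assertTrue(torch.allclose(...)), assert_close raises an
    # AssertionError whose message reports the number of mismatched
    # elements and the greatest absolute/relative differences.
    torch.testing.assert_close(got, expected, atol=1e-2, rtol=1e-2)
except AssertionError as e:
    print(e)
```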

@IlyasMoutawwakil
Member Author

Thanks everyone, I'm replacing this PR with #35873 that's less restrictive.

@guangy10
Contributor

Thanks everyone, I'm replacing this PR with #35873 that's less restrictive.

Yeah, this one looks much cleaner. Do you mind rerunning the slow export/executorch tests on this PR?

After that, can you run the cross-repo integration tests in Optimum (running them locally is fine)? @echarlaix recently moved ExecuTorch to a new repo, huggingface/optimum-executorch; however, all documentation/tutorials were deleted unintentionally after the move.

It used to be as simple as:

pip install optimum[exporters-executorch]

Override the installed transformers version to your dev version including this PR, then simply run

pytest executorch/*/test_*.py -s -vvvv --durations=0

@echarlaix @michaelbenayoun can you guide @IlyasMoutawwakil on how to run the executorch e2e tests?

@IlyasMoutawwakil
Member Author

Do you mind rerunning the slow export/executorch tests on this PR?

They pass. I made sure both the llama and cache slow tests pass on this PR before dropping it, so that it can be used for future reference when subclassing the tensor class.

After that, can you run cross-repo integration tests in Optimum (running it locally is fine) ?

Can do that later.
