Make Cache a subclass of torch.Tensor #35792
Conversation
Force-pushed from 1114e7e to b67b6eb
gante left a comment
In principle LGTM. I'm calling in the torch.export<>transformers expert to double-check that these changes are also okay for that goal 🤗
Question: a Cache object holds a list of tensors, usually a pair of tensors per layer. In some cases, different tensors of a cache can live on different devices. Would this conflict with the new inheritance?
Double-checks:
- Have you confirmed that the slow llama tests and slow cache tests have no regressions with respect to main? (`RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv` and `RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv`)
- Have you confirmed that llama + static cache + compilation preserves throughput? (can share a script if needed :) )
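On the device question above: a minimal sketch (hypothetical `CacheLikeTensor` name, not the PR's actual code) of how a torch.Tensor subclass can act purely as a container, with the per-layer tensors keeping their own devices independently of the wrapper:

```python
import torch

# Hypothetical sketch, not the PR's implementation: a torch.Tensor subclass
# used purely as a container for per-layer key/value tensors.
class CacheLikeTensor(torch.Tensor):
    def __new__(cls, key_cache, value_cache):
        # The wrapper itself is an empty 0-element tensor; the actual
        # storage lives in the per-layer tensors attached below.
        self = torch.Tensor._make_subclass(cls, torch.empty(0))
        self.key_cache = list(key_cache)      # one tensor per layer
        self.value_cache = list(value_cache)  # one tensor per layer
        return self

# The inner tensors keep their own devices independently of the wrapper,
# so per-layer placement (e.g. some layers on another GPU) still works.
cache = CacheLikeTensor([torch.zeros(1, 2, 4, 8)], [torch.zeros(1, 2, 4, 8)])
print(isinstance(cache, torch.Tensor))  # True
print(cache.key_cache[0].device.type)   # cpu
```

Since the wrapper holds no real data, the inheritance itself should not constrain where the inner tensors live.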
        {},
        proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
    )
# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
This is for optimum and you're part of optimum, so I'm assuming it's okay :D
Yeah, I'm not sure why this was needed either; tagging @echarlaix @mht-sharma for more info
It was needed to be able to record Cache-related operations in the fx graph. If an easier solution has been found, I'm all for it.
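For context, torch.fx records operations by threading Proxy objects through the traced code; a proxy factory like the one above customizes which proxy class gets created for Cache objects. A generic sketch of how ops end up as graph nodes (standard torch.fx, not the HFCacheProxy machinery):

```python
import torch
import torch.fx as fx

# Generic illustration of fx graph recording: during symbolic tracing,
# forward() receives Proxy objects instead of real tensors, and every
# operation performed on a proxy is recorded as a node in the graph.
class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

gm = fx.symbolic_trace(M())
print([n.op for n in gm.graph.nodes])
# e.g. ['placeholder', 'call_method', 'call_function', 'output']
```

A custom proxy class is what lets non-tensor containers like Cache participate in this recording, which is what the `proxy_factory_fn` hook was for.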
@guangy10 as requested on Slack, have a look if you're available 🙏
For the correctness testing: no extensive testing, but we do have some correctness guarantee for supported models. Also I'm not exactly sure if the
Alternatively, since the motivation is to handle the legacy torch script tracing (I assume the traffic to this path will get lower and lower over time), would it be a cleaner separation if we create a dedicated Cache subclass for it while keeping the one for PyTorch 2.0+ as is?
Shouldn't be an issue as we're not using the
One example is the QuantizedTensor subclass which has two dtypes (a public one
Running them right now (btw, is there a way to trigger them on the CI?). I was only running the llama fast tests and the llama + executorch integration tests.
Edit: confirmed these two tests fail on main as well. Running the first; in the second, the assertion is not verbose enough, so adding some verbosity:
Force-pushed from 5829a6a to da60604
Force-pushed from e4534a0 to 5ccb79c
Thanks everyone, I'm replacing this PR with #35873, which is less restrictive.
Yeah, this one looks much cleaner. Do you mind rerunning the slow export/executorch tests on this PR?
After that, can you run the cross-repo integration tests in Optimum (running them locally is fine)? @echarlaix recently moved ExecuTorch to a new repo. It used to be as simple as: override the installed
@echarlaix @michaelbenayoun can you guide @IlyasMoutawwakil on how to run the executorch e2e tests?
They pass. I made sure both the llama and cache slow tests pass on this PR before dropping it, so that it can serve as a future reference when subclassing the tensor class.
Can do that later.
What does this PR do?
Both torch script tracing and torch dynamo/fx place restrictions on input types (torch script has more), which makes export fail when one torch module (the model) passes another object (the cache) around as its input. Making Cache a subclass of torch.Tensor bypasses these issues and, imo, makes more sense, as the Cache class has no forward and is just a container of torch tensors.
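To illustrate the restriction (a minimal standalone repro, not the PR's code): torch.jit.trace only accepts tensors and nested lists/dicts/tuples of tensors as example inputs, so passing an arbitrary container object fails:

```python
import torch

class Container:  # a plain object holding a tensor -- not a torch.Tensor
    def __init__(self, t):
        self.t = t

def fn(x, c):
    return x + c.t

x = torch.ones(2)

# Tracing with only tensor inputs works fine:
traced = torch.jit.trace(lambda a, b: a + b, (x, x))

# Tracing with a non-tensor container as an input is rejected:
trace_failed = False
try:
    torch.jit.trace(fn, (x, Container(torch.ones(2))))
except (RuntimeError, TypeError):
    trace_failed = True
print(trace_failed)  # True
```

If Cache were itself a torch.Tensor, it would fall on the accepted side of this check, which is the motivation stated above.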
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.