-
Notifications
You must be signed in to change notification settings - Fork 6.7k
[WIP] Rolling KV cache for autoregressive generation #12773
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this! Some initial thoughts:
I think this two-level pattern (RollingKVCache as container + CacheLayer as layer cache) is sufficient to support at least all 3 use cases we've seen: SANA/KREA/GLM-Image, so this is good. We want to keep the API consistent across implementations, e.g., all layer cache classes would have update/reset/get methods and the container would have __getitem__, reset_state, etc. Not sure if we need a base class though - it might be sufficient for maintainers to enforce consistency during PR reviews.
Where should this live? This isn't a hook, so should be moved out. Not sure if we want a central place or keep them in individual model files. some of the caches are pretty model-specific
cc @DN6 - can you take a look and share some initial thoughts too? I'd like to move quickly with a simple & flexible design and we can adapt as we learn more from our use cases.
yeah, it should be moved out. I think with more models released, it can be stored in a one file and make a base class. Right now the base class would be kinda useless if we have only three models and each one is very specific on how caching is done So, storing in model file and keeping an eye on consistent implementation for new models sounds good to me. We can generalize easily if API is identical for new models |
What does this PR do?
Fixes #12600
Functionality-wise the self attention cache seems to work correctly, cross-attention has to be added and verfied. I added Krea to test the cache though I am not getting the same output as the original model yet. From quick debugging, looked to be related to timesteps or rope embeddings. Opening a draft as a reminder to myself to give this feature higher priority