Refactor MambaCache to modeling_mamba.py #38086
manueldeprada merged 78 commits into huggingface:main

Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
gante
left a comment
LGTM, added a few tiny nits 👍
@manueldeprada have you run the slow tests? (after we confirm slow tests are okay, let's tag arthur)
Ran the tests locally; there are some failures that already happened before this PR. Will run them in CI once I have permissions.
Same failures on CI as locally, @gante. Let me know if those slow tests should be fixed, ignored, or deleted.
Title changed from "MambaCache to modeling_mamba.py (parity with Zamba)" to "MambaCache to modeling_mamba.py"
    ):
        # Overwritten -- uses `cache_params` as opposed to `past_key_values`

        if use_cache and cache_params is None:
These two new additions are needed for BC and affect all classes that use MambaCache: they emulate the order in which generate() initialized MambaCache. See here.
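For illustration, a minimal sketch of the backward-compatible behavior described above (toy names such as `ToyMambaCache` are placeholders, not the actual transformers code): when a caller requests caching but passes no cache, the model builds one itself, in the same order `generate()` used to initialize it.

```python
class ToyMambaCache:
    """Stand-in for MambaCache: one state buffer per layer (plain lists here)."""

    def __init__(self, num_layers, state_size):
        self.states = [[0.0] * state_size for _ in range(num_layers)]


def forward(hidden, cache_params=None, use_cache=False):
    # BC shim: callers that only set use_cache=True still get a cache,
    # initialized where generate() used to initialize it.
    if use_cache and cache_params is None:
        cache_params = ToyMambaCache(num_layers=2, state_size=4)
    # ... run the layers, updating cache_params.states in place ...
    return hidden, (cache_params if use_cache else None)
```

Callers that already build and pass their own cache are unaffected; only the `use_cache=True, cache_params=None` combination takes the new path.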
There was a problem hiding this comment.
I don't think it is relevant to modify the code to emulate generate. generate is the abstraction that needs to be changed
    selective_scan_fn,
    causal_conv1d_fn,
    causal_conv1d_update,
    mamba_inner_fn,  # modular: no_replace
If we have the pragma, it should not appear here.
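A toy string-level sketch of what honoring a `# modular: no_replace` pragma could look like (hypothetical helper, not the real modular converter, which works on the syntax tree): names on lines carrying the pragma are skipped by the renamer.

```python
PRAGMA = "# modular: no_replace"


def rename_line(line, old, new):
    """Rename `old` -> `new` unless the line opts out via the pragma."""
    if PRAGMA in line:
        return line  # pragma present: leave the import untouched
    return line.replace(old, new)


lines = [
    "causal_conv1d_fn,",
    "mamba_inner_fn,  # modular: no_replace",
]
renamed = [rename_line(line, "mamba", "falcon_mamba") for line in lines]
```

With this rule, `mamba_inner_fn` keeps its original name in the generated file while unmarked names are rewritten as usual.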
    return updated_node

    def leave_ImportFrom(self, original_node, updated_node):
        """The imports from other file types (configuration, processing etc) should use original model name."""
Skipping the rename for absolute imports does not affect other models and is probably a reasonable assumption. Relative imports need to be renamed: `from .configuration_mamba import MambaConfig`.
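A toy string-level sketch of the rule stated above (the real converter implements it with libcst's `leave_ImportFrom` visitor): relative imports get the model name swapped, absolute imports are left alone.

```python
def rename_import(stmt, old_model, new_model):
    """Rename the model name only in relative imports (leading dot)."""
    if stmt.startswith("from ."):
        return stmt.replace(old_model, new_model)
    return stmt  # absolute import: skip the rename


rel = rename_import(
    "from .configuration_mamba import MambaConfig", "mamba", "falcon_mamba"
)
absolute = rename_import(
    "from transformers import MambaConfig", "mamba", "falcon_mamba"
)
```

Here `rel` becomes an import from `.configuration_falcon_mamba`, while the absolute import passes through untouched.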
ArthurZucker
left a comment
A few small things left!
[For maintainers] Suggested jobs to run (before merge): run-slow: falcon_mamba, mamba, mamba2
Thanks, @ArthurZucker! I ended up refactoring the entire thing; lmk if there's anything left!
ArthurZucker
left a comment
Thanks, that looks better 😉
Commits:

* Refactor MambaCache to modeling_mamba.py (parity with Zamba)
* ruff
* fix dummies
* update
* update
* remove mamba ref in cache tests
* remove cache_implementation from tests
* update
* ruff
* ruff
* sneaky regression
* model consistency
* fix test_multi_gpu_data_parallel_forward
* fix falcon slow tests
* ruff
* ruff
* add sample false
* try to fix slow tests
* Revert "fix test_multi_gpu_data_parallel_forward" (reverts commit 66b7162)
* fix tests on nvidia t4, remove dataparallel tests from mamba
* ruff
* remove DDP tests from mamba and falcon_mamba
* add explicit error for MambaCache
* mamba2 also needs to init cache in prepare_inputs_for_generation
* ruff
* ruff
* move MambaCache to its own file
* ruff
* unprotected import fix
* another attempt to fix unprotected imports
* Revert "another attempt to fix unprotected imports" (reverts commit 2338354)
* fixing unprotected import, attempt 3
* Update src/transformers/cache_utils.py
* ruff's fault
* fix arthur review
* modular falcon mamba
* found a hack
* fix config docs
* fix docs
* add export info
* merge modular falcon branch
* oopsie
* fix fast path failing
* new approach
* oopsie
* fix types
* Revert new pragma in modular (reverts commit 80b1cf1)
* trying another modular workaround
* review & fix ci
* oopsie
* clear prepare_inputs on mamba/mamba2/falcon_mamba
This PR moves the specialized `MambaCache` class from `cache_utils.py` to `src/transformers/models/mamba/modeling_mamba.py`. This is preliminary work for #38077.

Changes:

* Move `MambaCache` to its own file, aligning with Zamba, Bamba, etc., and decouple it from `generate`.
* Move cache initialization out of `forward()` into `prepare_inputs_for_generation`. See this comment.
* Most models initialize their cache in `prepare_inputs_for_generation`; only Mamba, Mamba2, and FalconMamba were doing it in `forward`, which is bad for torch.compile.
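A toy sketch of the torch.compile point above (plain Python, no torch; `ToyCache` and the function bodies are illustrative, not the transformers implementation): creating the cache in `prepare_inputs_for_generation` keeps `forward` free of the `cache is None` branch, so a compiled `forward` sees the same graph on every decoding step.

```python
class ToyCache:
    """Stand-in for MambaCache: one state buffer per batch row."""

    def __init__(self, batch_size, state_size):
        self.state = [[0.0] * state_size for _ in range(batch_size)]


def prepare_inputs_for_generation(input_ids, cache_params=None):
    # Cache is created once, up front -- not lazily inside forward.
    if cache_params is None:
        cache_params = ToyCache(batch_size=len(input_ids), state_size=4)
    return {"input_ids": input_ids, "cache_params": cache_params}


def forward(input_ids, cache_params):
    # No `if cache_params is None` branch here: the (hypothetically
    # compiled) graph is branch-stable across decoding steps.
    for row in cache_params.state:
        row[0] += 1.0  # stand-in for a recurrent state update
    return input_ids, cache_params


# Toy decoding loop: the cache flows through prepare_inputs every step,
# but is only ever constructed on the first one.
inputs, cache = [[1], [2]], None
for _ in range(3):
    model_inputs = prepare_inputs_for_generation(inputs, cache)
    _, cache = forward(**model_inputs)
```

After three steps the cache has been updated three times while `forward` itself never had to decide whether to allocate it.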