Compile compatibility for decoder-only models #32617
zucchini-nlp merged 6 commits into huggingface:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
gante left a comment:
Added a few comments, mostly about aligning with llama
Ran `test_generate_compile_fullgraph` and `test_static_cache_matches_dynamic` on all models + ran slow tests on models touched by this PR.
💛
```diff
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  if inputs_embeds is not None and cache_position[0] == 0:
-     model_inputs = {"inputs_embeds": inputs_embeds}
+     model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
```
not really, bloom has alibi and needs 2D attention for that. So we can't expand it to 4D, and choose to append zeros to attn to make it static shape.
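For intuition, a minimal sketch of the static-shape trick described above (the helper name `pad_attention_mask_to_static` is hypothetical, not Bloom's actual code): right-padding the 2D attention mask with zeros keeps its shape fixed across decoding steps.

```python
import torch
import torch.nn.functional as F

def pad_attention_mask_to_static(mask_2d: torch.Tensor, target_len: int) -> torch.Tensor:
    # Right-pad the 2D attention mask with zeros so its last dimension
    # always equals `target_len`; the padded positions are masked out,
    # and the shape stays static for compile.
    pad_len = target_len - mask_2d.shape[-1]
    return F.pad(mask_2d, (0, pad_len), value=0)

mask = torch.ones(1, 7, dtype=torch.long)          # 7 real tokens
static_mask = pad_attention_mask_to_static(mask, 16)  # shape (1, 16)
```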
Updated with @gante's comments and used the new RoPE modeling in all models. Ready for review!
Failing tests are not related.
does it support compile? (not seeing `supports_static_cache`)
yes, it does. You might have missed it :)
```python
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
```
this is potentially breaking, no? (no more `offset`)
Hmm right, lemme check this
update: just verified we don't need to slice anymore, because we apply RoPE directly on the current position. Previously we applied RoPE for all positions up to the current one and had to slice out the cached positions.
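For reference, the offset-free pattern looks roughly like this (a sketch following the shape of the transformers RoPE helpers, assuming `cos`/`sin` were already gathered for the current positions, so no slicing or `offset` is needed):

```python
import torch

def rotate_half(x):
    # Rotate half the hidden dims of the input, as in the diff above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # cos/sin are already indexed by the current cache_position,
    # so there is no `offset` argument and no slicing of cached positions.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q = torch.randn(1, 2, 3, 8)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 3, 8)
cos = torch.ones(1, 1, 3, 8)   # degenerate angles for illustration
sin = torch.zeros(1, 1, 3, 8)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
```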
```python
if past_key_value is not None:
    # Activate slicing cache only if the config has a value `sliding_windows` attribute
    cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
    kv_seq_len = key_states.shape[-2] + cache_position[0]
```
I don't remember why we don't use `cache_position[-1]`
Because the last position is the whole past KV length, which gives an incorrect length in pre-fill or uncached generation. Maybe we should switch to simply `past_length = cache_position[-1]` everywhere?
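A small arithmetic sketch of the difference (plain Python lists standing in for tensors; the numbers are made up for illustration). The two indices only diverge when more than one token is processed at once, i.e. in pre-fill:

```python
# Pre-fill: 4 new tokens, empty cache -> cache_position = [0, 1, 2, 3]
cache_position = list(range(4))
key_len = 4                                # key_states.shape[-2]
kv_seq_len = key_len + cache_position[0]   # 4 + 0 = 4, correct
overcount = key_len + cache_position[-1]   # 4 + 3 = 7, wrong in pre-fill

# Decoding: 1 new token after 4 cached -> cache_position = [4]
cache_position = [4]
kv_seq_len_decode = 1 + cache_position[0]  # 1 + 4 = 5, correct
# Here cache_position[0] == cache_position[-1], so both agree during decoding.
```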
gante left a comment:
Thank you for these very laborious changes 🙏
Force-pushed b3c91c0 to 1f328f0.
@simonJJJ I added the new RoPE embedding for Qwen2-VL in this PR. Since I changed Qwen2, the changes were automatically propagated with `fix-copies`.
@ArthurZucker @gante changed deprecation to v4.46 and added Qwen2-VL. Ran the tests again to check everything is okay. Let me know if you have any comments.
```diff
  # cache_position must be valid here no matter which cache we use
- past_seen_tokens = cache_position[0] if past_key_values is not None else 0
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
```
Same as in Llama: using `cache_position` here is dynamic control flow, which is not currently supported by compile. The fullgraph-compile test fails without this change.
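For intuition, a minimal sketch (`MiniStaticCache` is a hypothetical stand-in, not the real Cache API): `get_seq_length()` can return a plain Python int, which the compiler resolves while tracing, whereas `cache_position[0]` is a 0-dim tensor, so using it in Python-level logic is data-dependent control flow.

```python
import torch

class MiniStaticCache:
    # Hypothetical stand-in for a Cache object.
    def __init__(self):
        self.seen_tokens = 0

    def get_seq_length(self):
        # Plain Python int: the branch below is resolved at trace time.
        return self.seen_tokens

past_key_values = MiniStaticCache()
past_key_values.seen_tokens = 5
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

# By contrast, cache_position[0] is a 0-dim tensor; feeding it into
# Python control flow makes the graph data-dependent.
cache_position = torch.arange(5, 12)
```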
@zucchini-nlp happy with the changes, feel free to merge! (given that you mentioned that you re-ran the tests 💛)
Yes, I was thinking exactly that: rebase on main and re-run the tests one more time.
Tests are passing, including slow ones. So, merging.
Can we update the tracker in #28981?
* squash into one commit
* add qwen2-vl for rope standardization
* fix mistral compile
* fix qwen2-vl
* fix-copies
What does this PR do?
Recently we merged a few PRs deprecating the old-style cache in all decoder-only models. This PR is a continuation: here we verify that all newly deprecated models support static cache and are compatible with torch.compile. The main change is in RoPE, to get rid of dynamic control flow.
A few exceptions cannot be supported yet: MoE models and some others with dynamic control flow, like Phi3 or Chameleon.
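To illustrate what "dynamic control flow" means here (a hedged sketch with made-up functions, not actual Phi3 or Chameleon code): a Python `if` on a tensor value cannot be traced with `fullgraph=True`, while a branch-free tensor rewrite can.

```python
import torch

def gate_dynamic(x, score):
    # Python branch on a tensor value: data-dependent control flow,
    # which torch.compile(..., fullgraph=True) cannot trace through.
    if score.item() > 0:
        return x * 2
    return x

def gate_static(x, score):
    # Branch-free equivalent: a pure tensor expression, compile-friendly.
    scale = torch.where(score > 0, torch.tensor(2.0), torch.tensor(1.0))
    return x * scale

x = torch.ones(3)
out = gate_static(x, torch.tensor(1.0))  # same result as gate_dynamic
```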
Ran `test_generate_compile_fullgraph` and `test_static_cache_matches_dynamic` on all models + ran slow tests on models touched by this PR. In the next PR I can start deprecating the old cache in encoder-decoder models, starting from Bart and GPT models.