Compile compatibility for decoder-only models #32617

Merged
zucchini-nlp merged 6 commits into huggingface:main from zucchini-nlp:compile-models
Sep 9, 2024

Conversation

@zucchini-nlp (Member) commented Aug 12, 2024

What does this PR do?

Recently we merged a few PRs deprecating the old-style cache in all decoder-only models. This PR is a continuation of that work: here we verify that all newly migrated models support the static cache and are compatible with torch.compile. The main change is in RoPE, to get rid of dynamic control flow.
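A minimal sketch of the compile-friendly direction described above (illustrative only, not the exact transformers code): instead of lazily growing a cos/sin cache behind a branch on the runtime sequence length, RoPE can compute cos/sin directly from the requested position ids, so there is no data-dependent control flow for `torch.compile(fullgraph=True)` to trip over.

```python
import torch

def rope_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor):
    # inv_freq: (dim/2,), position_ids: (batch, seq_len)
    # No branching on sequence length: everything is derived from the inputs.
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((freqs, freqs), dim=-1)  # (batch, seq_len, dim)
    return emb.cos(), emb.sin()

# Toy head dim of 8, a single sequence of 4 positions
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
position_ids = torch.arange(4)[None, :]
cos, sin = rope_cos_sin(inv_freq, position_ids)
print(cos.shape)  # torch.Size([1, 4, 8])
```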

A few exceptions cannot be supported yet: MoE models and some others with dynamic control flow, like Phi3 or Chameleon.

Ran `test_generate_compile_fullgraph` and `test_static_cache_matches_dynamic` on all models + ran slow tests on models touched by this PR.

In the next PR I can start deprecating the old cache in encoder-decoder models, starting from Bart and the GPT models.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Contributor) left a comment

Added a few comments, mostly about aligning with Llama.

> Ran `test_generate_compile_fullgraph` and `test_static_cache_matches_dynamic` on all models + ran slow tests on models touched by this PR.

💛

```diff
 # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 if inputs_embeds is not None and cache_position[0] == 0:
-    model_inputs = {"inputs_embeds": inputs_embeds}
+    model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
```
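A toy illustration of the check in the hunk above (names simplified): `cache_position` starts at 0 on the first forward pass and advances with each generated token, so `cache_position[0] == 0` identifies the prefill step, the only step where `inputs_embeds` should be used.

```python
import torch

def is_prefill(cache_position: torch.Tensor) -> bool:
    # True only on the first generation step, when nothing is cached yet
    return bool(cache_position[0] == 0)

print(is_prefill(torch.arange(0, 5)))  # True  (prefill of 5 tokens)
print(is_prefill(torch.tensor([5])))   # False (a later decoding step)
```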
Contributor

missing `# Copied from ...`?

Member Author

Not really: Bloom uses alibi and needs a 2D attention mask for that. So we can't expand the mask to 4D, and instead we append zeros to the attention mask to give it a static shape.
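A toy sketch of the padding idea above (shapes and names are illustrative): instead of letting the 2D attention mask grow dynamically during decoding, pad it with zeros up to a fixed maximum length so its shape stays static across steps.

```python
import torch

max_len = 8                     # fixed maximum cache length
attn_mask = torch.ones(1, 5)    # 5 real tokens attended to so far
# Pad with zeros so the mask shape is always (1, max_len)
padded = torch.cat(
    [attn_mask, attn_mask.new_zeros(1, max_len - attn_mask.shape[1])], dim=-1
)
print(padded.shape)  # torch.Size([1, 8])
```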

@zucchini-nlp (Member Author)

Updated with @gante's comments and used the new RoPE modeling in all models. Ready for review!

@zucchini-nlp (Member Author)

Failing tests are unrelated.

@ArthurZucker (Collaborator) left a comment

💎 Thanks so much for this tedious work, well done 🥳
What is left is to make sure the compile tests pass!

Collaborator

does it support compile? (not seeing the `supports_static_cache` flag)
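For context on the flag being asked about: transformers models opt into the static-cache/compile tests via a class attribute (`_supports_static_cache` in `PreTrainedModel` subclasses). A minimal, simplified sketch:

```python
class MyDecoderModel:
    # Opts the model into static-cache / torch.compile tests;
    # the class name here is hypothetical, the attribute name mirrors transformers.
    _supports_static_cache = True

print(MyDecoderModel._supports_static_cache)  # True
```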

Member Author

```python
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
```
Collaborator

this is potentially breaking, no? (no more `offset`)

Member Author

Hmm right, lemme check this

Member Author

Update: just verified we don't need to slice anymore, because we apply RoPE directly on the current position. Previously we applied RoPE for all positions up to the current one and had to slice out the cached positions.
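The two approaches discussed above can be sketched numerically (illustrative shapes only, not the library code). Old style: compute frequencies for all positions up to the current one, then slice off the cached prefix with an `offset`. New style: compute directly for the current position ids, so no slicing (and no `offset` argument) is needed, and both give the same result.

```python
import torch

dim, past, new = 8, 6, 1  # toy head dim, 6 cached tokens, 1 new token
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))

# Old style: compute for all past+new positions, then slice with an offset
all_pos = torch.arange(past + new).float()
freqs_all = all_pos[:, None] * inv_freq[None, :]
old_tail = freqs_all[past:]

# New style: compute directly for the current positions, no slicing
cur_pos = torch.arange(past, past + new).float()
freqs_cur = cur_pos[:, None] * inv_freq[None, :]

print(torch.allclose(old_tail, freqs_cur))  # True
```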

```python
if past_key_value is not None:
    # Activate slicing cache only if the config has a `sliding_window` attribute
    cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
    kv_seq_len = key_states.shape[-2] + cache_position[0]
```
Collaborator

I don't remember why we don't use `cache_position[-1]`

Member Author

Because the last position is the whole past KV length, which causes an incorrect length in pre-fill or uncached generation. Maybe we should switch to simply `past_length = cache_position[-1]` everywhere?
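A small numeric illustration of the point above (toy values, not the library code): for a prefill of N tokens with an empty cache, `cache_position = [0, 1, ..., N-1]`, so `cache_position[0]` equals the number of already-cached tokens (0 here), while `cache_position[-1]` is N-1, which does not give the right total KV length during prefill.

```python
import torch

new_tokens, past_len = 4, 0  # prefill: 4 new tokens, empty cache
cache_position = torch.arange(past_len, past_len + new_tokens)

# key_states.shape[-2] would be `new_tokens` here
kv_seq_len_first = new_tokens + int(cache_position[0])   # 4: correct total KV length
kv_seq_len_last = new_tokens + int(cache_position[-1])   # 7: wrong during prefill
print(kv_seq_len_first, kv_seq_len_last)
```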

@gante (Contributor) left a comment

Thank you for these very laborious changes 🙏

@zucchini-nlp (Member Author)
@simonJJJ I added the new RoPE embedding for Qwen2-VL in this PR. Since I changed Qwen2, the changes were automatically propagated through the copy statements. I remember you had a PR to fix RoPE for FA2; can you check whether the current version works as you expect?

@zucchini-nlp (Member Author)
@ArthurZucker @gante changed the deprecation to v4.46 and added Qwen2-VL. Ran the tests again to check everything is okay. Let me know if you have any comments.


```diff
 # cache_position must be valid here no matter which cache we use
-past_seen_tokens = cache_position[0] if past_key_values is not None else 0
+past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
```
Member Author

Same as in Llama: using `cache_position` here is dynamic control flow, which is not currently supported by compile. The fullgraph-compile test fails without this change.
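A simplified sketch of the distinction above (the cache class below is a stand-in, not the transformers `StaticCache`): reading `cache_position[0]` pulls a value out of a tensor, which the compiler treats as data-dependent, whereas asking the cache object for its length via a method like `get_seq_length()` can stay on the traceable path.

```python
class ToyCache:
    """Stand-in for a KV cache that tracks how many tokens it has seen."""

    def __init__(self) -> None:
        self.seen = 0

    def get_seq_length(self) -> int:
        return self.seen

past_key_values = ToyCache()
past_key_values.seen = 7  # pretend 7 tokens are cached

# Mirrors the "+" side of the diff above
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
print(past_seen_tokens)  # 7
```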

@gante (Contributor) commented Sep 6, 2024

@zucchini-nlp happy with the changes, feel free to merge! (given that you mentioned that you re-ran the tests 💛 )

@zucchini-nlp (Member Author)
Yes, I was thinking exactly that: rebase on main and re-run the tests one more time.

@zucchini-nlp (Member Author)
Tests are passing, including the slow ones. So, merging.

@zucchini-nlp zucchini-nlp merged commit 65bb284 into huggingface:main Sep 9, 2024
@anijain2305 (Contributor)
Can we update the tracker in #28981?

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* squash into one commit

* add qwen2-vl for rope standardization

* fix mistral compile

* fix qwen2-vl

* fix-copies