Generate: Fix modern llm generate calls with synced_gpus #34095

gante merged 5 commits into huggingface:main from

Conversation
@SunMarc this should help with FSDP +
```python
# This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache = False
past_key_values = model_kwargs.get("past_key_values", None)
if isinstance(past_key_values, DynamicCache) or (
    isinstance(past_key_values, EncoderDecoderCache)
    and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
    if past_key_values.get_seq_length() == 0:
        start_from_empty_dynamic_cache = True
```
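For context, the check above can be exercised standalone. Below is a sketch with minimal stub classes standing in for transformers' `DynamicCache` and `EncoderDecoderCache` (the stubs and the `starts_from_empty_dynamic_cache` helper are illustrative, not the real implementations):

```python
# Stub caches standing in for transformers' real cache classes (sketch only).
class DynamicCache:
    def __init__(self):
        self._seq_len = 0  # a fresh dynamic cache holds no tokens yet

    def get_seq_length(self):
        return self._seq_len


class EncoderDecoderCache:
    def __init__(self, self_attention_cache):
        # real class wraps a decoder self-attention cache plus a cross-attention cache
        self.self_attention_cache = self_attention_cache

    def get_seq_length(self):
        return self.self_attention_cache.get_seq_length()


def starts_from_empty_dynamic_cache(model_kwargs):
    """Hypothetical helper mirroring the check in the diff above."""
    past_key_values = model_kwargs.get("past_key_values", None)
    if isinstance(past_key_values, DynamicCache) or (
        isinstance(past_key_values, EncoderDecoderCache)
        and isinstance(past_key_values.self_attention_cache, DynamicCache)
    ):
        return past_key_values.get_seq_length() == 0
    return False
```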
Simplifies logic in assisted generation: see the new is_first_iteration variable and its uses :)
ArthurZucker
left a comment
The decorrelation between input preparation for generation and the modeling is very nice.
I don't know how well we test this (whether the slow CIs were crying or not), but if yes, then it's already tested and good to go!
This fixes the error I was seeing here: Thank you so much!
@ArthurZucker I don't think this is being tested! @SunMarc -- I couldn't find any related test, but multigpu tests have a more elaborate setup, so I could be missing something. Can you confirm? Meanwhile, I'm merging since this PR unblocks users. If there is no test, I'll open a follow-up PR :)
I'm not aware of any tests related to multi-gpu and generate with `synced_gpus=True`. I will have a look at this since we also need to add them for deepspeed and FSDP! cc @muellerzr
Does this fully address the issue with When
cc: @gante
@jiayuanmark the snippet in the PR header, which serves as a base test case, runs without problems 🤔 Would you be able to create a minimal reproducer for your issue, and open a new issue?
What does this PR do?
Step 5 in #32685
Fixes #32885
Fixes #32603
Fixes #32641
Modern LLMs, i.e. LLMs that support our cache classes, currently fail when the input has a batch size > 1 and `synced_gpus=True`.

On `main`, this is what happens with `synced_gpus`:
- `cache_position` stops being updated when generation finishes on a given device, causing cache indexing errors on that device (the cache continues growing because we keep doing dummy forward passes);
- if `cache_position` kept being updated instead, then slicing `input_ids` would get out of bounds for the dummy computations (we stop updating `input_ids`, so it stops growing).

This PR makes the changes needed to enable generation with the behavior above.
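The failure mode above can be illustrated with a toy simulation in plain Python (no transformers code; `synced_generate` and its bookkeeping are made up for this sketch). With `synced_gpus`, every device keeps running forward passes until all devices are done, and a device that has finished must freeze both its `cache_position` and its `input_ids` so the two stay consistent during the dummy passes:

```python
# Toy simulation of synced_gpus generation across "devices" (illustrative only).
def synced_generate(lengths, max_steps):
    """Each entry in `lengths` is the step at which that device finishes.

    Returns, per device, the cache_position used at each step, plus the
    final input_ids length. A finished device keeps its cache_position
    frozen AND stops growing input_ids, so slicing stays in bounds.
    """
    cache_positions = [0 for _ in lengths]
    input_lens = [1 for _ in lengths]  # start from a 1-token prompt
    history = [[] for _ in lengths]
    for step in range(max_steps):
        if all(step >= n for n in lengths):
            break  # every device is done; no more dummy passes needed
        for dev, n in enumerate(lengths):
            finished = step >= n
            history[dev].append(cache_positions[dev])
            if not finished:
                # real token: grow input_ids and advance cache_position
                input_lens[dev] += 1
                cache_positions[dev] += 1
            # finished devices run a dummy forward pass with unchanged
            # inputs, keeping cache indexing valid on that device
    return history, input_lens
```

Running it with two devices that finish at steps 2 and 4 shows the finished device reusing the same `cache_position` for its dummy passes while the other device keeps advancing.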
💛 Please note that, because of the efforts in #32685, updating model input preparation requires an update in a single function, as opposed to an update per model 💛
Test script (call with 2+ GPUs) that fails before this PR (from this comment):