[Assistant Generation] Improve Encoder Decoder #26701
Conversation
src/transformers/generation/utils.py
Outdated
```python
if n_matches == int(assistant_model.max_assistant_tokens):
    assistant_model.max_assistant_tokens += 2.0
else:
    assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
```
The heuristic to increase/decrease the number of "look-ahead" tokens doesn't work well for Whisper. Can we allow the user to somehow disable it, maybe via a config attribute?
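For illustration, a minimal sketch of what such a switch could look like. `num_assistant_tokens_schedule` is a hypothetical config attribute here (the names follow the renaming discussed later in this thread), and `n_matches` / `num_assistant_tokens` are assumed to come from the surrounding assisted-decoding loop:

```python
# Sketch only: `num_assistant_tokens_schedule` is a hypothetical config attribute,
# with "heuristic" reproducing the current behaviour and "constant" disabling it.
schedule = getattr(assistant_model.generation_config, "num_assistant_tokens_schedule", "heuristic")
if schedule == "heuristic":
    if n_matches == int(num_assistant_tokens):
        num_assistant_tokens += 2.0
    else:
        num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
# with "constant", the number of look-ahead tokens is never updated
```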
```python
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
assist_attn = torch.ones_like(candidate_input_ids)
```
Do we really need this, @gante? Allocating new memory here on every step leads to slowdowns that are not insignificant for Distil-Whisper.
Yes, this makes sense to remove, since it is the default attention mask! 👍
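For context, a minimal sketch of the simplification being agreed on here: attending to every token is the model's default behavior when no mask is passed, so the explicit all-ones mask adds nothing (variable names taken from the diff above):

```python
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
# Previously, an all-ones mask was allocated on every assisted-decoding step:
#   assist_attn = torch.ones_like(candidate_input_ids)
# Since attending to every token is the default when no mask is passed, the
# allocation (and the corresponding `attention_mask=...` argument) can be dropped.
```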
The documentation is not available anymore as the PR was closed or merged.
```diff
 new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
 if len(logits_processor) > 0:
-    for i in range(candidate_length):
+    for i in range(candidate_length + 1):
```
That was a bug previously: we forgot to apply the logits processors to the last logit here.
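A sketch of the corrected loop, assuming `cur_len` is the sequence length before the candidate tokens (as in the surrounding function); the `+ 1` ensures the processors also see the final position, from which the larger model samples its bonus token:

```python
if len(logits_processor) > 0:
    # candidate_length candidate positions plus the final position, whose logit
    # yields the extra token sampled by the larger model
    for i in range(candidate_length + 1):
        new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
```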
```diff
     probs = new_logits.softmax(dim=-1)
     selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
 else:
-    selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
+    selected_tokens = new_logits.argmax(dim=-1)
```
The slicing is unnecessary here, as `new_logits` is already sliced.
…to improve_assistant_generation_enc_dec
Fantastic, thank you for the upgrades @patrickvonplaten 🔥
Only added two minor, optional nits.
> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)

max_assistant_tokens (`int`, *optional*, defaults to 5):
Perhaps we can take the chance to give this variable a better name: `assistant_tokens` or similar. `max_assistant_tokens` implies that the assistant will never cross this limit, but, as we can see in `max_assistant_tokens_schedule` (which should also be renamed accordingly), that is not true :)
Poor original naming choice by me :D
Sounds good!
src/transformers/generation/utils.py
Outdated
```diff
-assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
+if hasattr(assistant_model, "max_assistant_tokens"):
+    warnings.warn(
+        "Setting `max_assistant_tokens` via `assistant_model.max_assistant_tokens` is deprecated and will be removed in v5. Make sure to set `max_assistant_tokens` via the generation_config instead.",
```
Perhaps we can deprecate this earlier (like in v4.37)?
I haven't seen users fiddling with this internal variable :)
Ok for me!
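For reference, a minimal sketch of how the deprecation could resolve the value, assuming the attribute ends up renamed to `num_assistant_tokens` on the generation config (per the naming discussion above):

```python
import warnings

if hasattr(assistant_model, "max_assistant_tokens"):
    warnings.warn(
        "Setting `max_assistant_tokens` via `assistant_model.max_assistant_tokens` is deprecated. "
        "Set `num_assistant_tokens` via the generation config instead.",
        FutureWarning,
    )
    # honour the legacy attribute for now so existing code keeps working
    num_assistant_tokens = assistant_model.max_assistant_tokens
else:
    num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
```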
…to improve_assistant_generation_enc_dec
…into improve_assistant_generation_enc_dec
The failing Hub test seems to be flaky. This PR is ready for a final review.
```python
inputs_embeds = self.embed_tokens(input) * self.embed_scale

if attention_mask is None:
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
```
This was a bug previously: the attention_mask should cover the length of inputs_embeds plus the past_key_values length.
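A sketch of the corrected default, assuming `past_key_values_length` is in scope in the surrounding decoder (as it is in most `transformers` decoder implementations):

```python
if attention_mask is None:
    # The default mask must cover the cached positions as well, i.e. have length
    # inputs_embeds length + past_key_values length, not just the new tokens.
    attention_mask = torch.ones(
        inputs_embeds.shape[0],
        inputs_embeds.shape[1] + past_key_values_length,
        device=inputs_embeds.device,
    )
```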
ArthurZucker left a comment
Very clean, thanks for improving the performance!
| - `"_heuristic_`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2 else | ||
| reduce by 1 |
Wondering if the schedule parameters should be hard-coded, but fine for me.
What does this PR do?
This PR speeds up assistant generation / speculative decoding for encoder-decoder models such as Distil-Whisper by ~20-30%.
Improvements:
- Pre-compute `assistant_encoder_outputs` so that the inputs are not encoded twice (gives ~20% speed-up), as sketched below
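A minimal sketch of the idea (not the exact PR code): the assistant's encoder runs a single time before the assisted-decoding loop, and its outputs are reused at every step. `input_features`, `done`, and `candidate_input_ids` are placeholders:

```python
# Run the assistant's encoder exactly once, before the assisted-decoding loop.
assistant_encoder_outputs = assistant_model.get_encoder()(input_features)

while not done:  # simplified assisted-decoding loop
    assistant_outputs = assistant_model(
        decoder_input_ids=candidate_input_ids,
        encoder_outputs=assistant_encoder_outputs,  # reused, never recomputed
    )
    # ... score the candidates with the large model, accept/reject, extend ids ...
```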