
[Assistant Generation] Improve Encoder Decoder #26701

Merged
patrickvonplaten merged 13 commits into main from improve_assistant_generation_enc_dec
Oct 11, 2023

Conversation

patrickvonplaten (Contributor) commented Oct 9, 2023

What does this PR do?

This PR speeds up assistant generation / speculative decoding for encoder-decoder models such as Distill-Whisper by ~20-30%.

Improvements:

  • If the assistant and the main model share the same encoder, allow the user to pass assistant_encoder_outputs so that the inputs are not encoded twice (gives ~20% speed-up)
  • In the inner loop we don't have to allocate tensors for the attention mask every iteration; the model creates the default mask automatically when needed (gives ~3-4% speed-up)
  • The heuristic to increase / decrease the number of "look-ahead" tokens doesn't work well for whisper, can we maybe allow the user to somehow disable it? Maybe via a config attribute?
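As background, the assisted-generation loop these improvements target can be sketched in plain Python, with toy token-level models standing in for the assistant and the main model (all names below are illustrative, not the transformers API):

```python
def assisted_step(main_model, assistant_model, input_ids, num_assistant_tokens):
    """One round of assisted generation with greedy toy models.

    `main_model` / `assistant_model` map a token list to the next token;
    they are didactic stand-ins for real forward passes.
    """
    # 1. The cheap assistant proposes a block of candidate tokens.
    candidates = list(input_ids)
    for _ in range(num_assistant_tokens):
        candidates.append(assistant_model(candidates))

    # 2. The main model checks the candidates and keeps the longest
    #    matching prefix. In a real model this check is a single batched
    #    forward pass, which is where the speed-up comes from.
    n_matches = 0
    for i in range(len(input_ids), len(candidates)):
        if main_model(candidates[:i]) == candidates[i]:
            n_matches += 1
        else:
            break

    # 3. Accept the matches plus one token from the main model itself.
    accepted = candidates[: len(input_ids) + n_matches]
    accepted.append(main_model(accepted))
    return accepted, n_matches

# Toy models: the assistant agrees with the main model except after token 3.
main = lambda seq: seq[-1] + 1
assistant = lambda seq: seq[-1] + 1 if seq[-1] != 3 else 99

out, n_matches = assisted_step(main, assistant, [1, 2], 5)
assert out == [1, 2, 3, 4]
assert n_matches == 1
```

Because the verification pass is batched, every accepted assistant token is a decoding step the large model did not have to run autoregressively.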

@patrickvonplaten patrickvonplaten marked this pull request as draft October 9, 2023 17:53
if n_matches == int(assistant_model.max_assistant_tokens):
    assistant_model.max_assistant_tokens += 2.0
else:
    assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
patrickvonplaten (Contributor, Author):
The heuristic to increase / decrease the number of "look-ahead" tokens doesn't work well for whisper, can we maybe allow the user to somehow disable it? Maybe via a config attribute?
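For reference, the heuristic in question can be isolated as a pure function; the `schedule="constant"` branch below is a hypothetical illustration of the opt-out being proposed, not an existing config attribute:

```python
def update_max_assistant_tokens(current, n_matches, schedule="heuristic"):
    """The dynamic look-ahead schedule as a pure function.

    When every speculative token was accepted, look further ahead next
    time; otherwise shrink the window, never below 1. `schedule="constant"`
    sketches the kind of opt-out discussed above (hypothetical name).
    """
    if schedule == "constant":
        return current
    if n_matches == int(current):
        return current + 2.0
    return max(1.0, current - 1.0)

# All 5 candidates accepted -> grow the window.
assert update_max_assistant_tokens(5.0, 5) == 7.0
# A miss -> shrink, but never below one token.
assert update_max_assistant_tokens(1.0, 0) == 1.0
# A disabled schedule keeps the window fixed, regardless of matches.
assert update_max_assistant_tokens(5.0, 5, schedule="constant") == 5.0
```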

# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
assist_attn = torch.ones_like(candidate_input_ids)
patrickvonplaten (Contributor, Author):
Do we really need this @gante? Allocating new memory here every time leads to slowdowns that are not insignificant for Distil Whisper.
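The point can be illustrated without torch: building the all-ones mask by hand reproduces exactly what models construct internally when `attention_mask` is None, so the per-iteration allocation buys nothing. A minimal list-based sketch:

```python
def default_attention_mask(input_ids):
    # Mirrors what many transformers models do internally when
    # attention_mask is None: attend to every position.
    return [[1] * len(seq) for seq in input_ids]

candidate_input_ids = [[5, 7, 9]]

# The explicit mask the loop was allocating each iteration
# (the list equivalent of torch.ones_like(candidate_input_ids)):
explicit_mask = [[1] * len(seq) for seq in candidate_input_ids]

# Identical to the model's own default, so the allocation is redundant.
assert default_attention_mask(candidate_input_ids) == explicit_mask
```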

gante (Contributor):

Yes, this makes sense to remove, since it is the default attention mask! 👍

HuggingFaceDocBuilderDev commented Oct 9, 2023

The documentation is not available anymore as the PR was closed or merged.

new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
if len(logits_processor) > 0:
-    for i in range(candidate_length):
+    for i in range(candidate_length + 1):
patrickvonplaten (Contributor, Author):

That was a bug previously. We forgot to apply the logits processors to the last logit here
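A tiny standalone check of the off-by-one: `new_logits` holds `candidate_length + 1` positions (the candidates plus the extra logit after the last candidate), so the old loop bound left the final position unprocessed:

```python
candidate_length = 3
# new_logits has candidate_length + 1 positions: the candidate tokens plus
# the one extra logit the main model produces after the last candidate.
new_logits = [[0.1], [0.2], [0.3], [0.4]]

processed_before = set(range(candidate_length))      # old loop bound
processed_after = set(range(candidate_length + 1))   # fixed loop bound

# The old bound skipped the final position, so the logits processors were
# never applied to the last logit.
assert len(new_logits) - len(processed_before) == 1
assert processed_after == set(range(len(new_logits)))
```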

gante (Contributor):

Good catch 👀

    probs = new_logits.softmax(dim=-1)
    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
    selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
patrickvonplaten (Contributor, Author):

This is unnecessary here as new_logits is already sliced
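Illustrated with plain lists: once `new_logits` has been narrowed to the last `candidate_length + 1` positions, applying the same slice again is a no-op:

```python
candidate_length = 3
logits = list(range(10))  # stand-in for outputs.logits along the time axis

# new_logits is already narrowed to the last candidate_length + 1 positions...
new_logits = logits[-candidate_length - 1 :]

# ...so re-applying the same slice before argmax changes nothing.
assert new_logits[-candidate_length - 1 :] == new_logits
```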

@patrickvonplaten patrickvonplaten marked this pull request as ready for review October 10, 2023 12:20
gante (Contributor) left a comment:

Fantastic, thank you for the upgrades @patrickvonplaten 🔥

Only added two minor, optional nits.


> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)

max_assistant_tokens (`int`, *optional*, defaults to 5):
gante (Contributor) commented Oct 10, 2023:

Perhaps we can take the chance to give a better name to this variable: assistant_tokens or similar. max_assistant_tokens implies that the assistant will never cross this limit but, as we can see in max_assistant_tokens_schedule (which should also be renamed accordingly), that is not true :)

Poor original naming choice by me :D

patrickvonplaten (Contributor, Author):

Sounds good!

assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
if hasattr(assistant_model, "max_assistant_tokens"):
    warnings.warn(
        "Setting `max_assistant_tokens` via `assistant_model.max_assistant_tokens` is deprecated and will be removed in v5. Make sure to set `max_assistant_tokens` via the generation_config instead.",
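The back-compat pattern in the snippet above can be sketched as a standalone function (simplified; `resolve_max_assistant_tokens` and `DummyAssistant` are names invented for this illustration, not the library API):

```python
import warnings

def resolve_max_assistant_tokens(assistant_model, generation_config_value):
    """Honour the legacy attribute if a user set it, but warn that the
    generation config is the supported path going forward."""
    if hasattr(assistant_model, "max_assistant_tokens"):
        warnings.warn(
            "Setting `max_assistant_tokens` via `assistant_model.max_assistant_tokens` "
            "is deprecated. Set it via the generation config instead.",
            FutureWarning,
        )
        return assistant_model.max_assistant_tokens
    return generation_config_value

class DummyAssistant:
    pass

legacy = DummyAssistant()
legacy.max_assistant_tokens = 8

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert resolve_max_assistant_tokens(legacy, 5) == 8            # legacy wins, with a warning
    assert resolve_max_assistant_tokens(DummyAssistant(), 5) == 5  # config value otherwise
assert len(caught) == 1 and issubclass(caught[0].category, FutureWarning)
```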
gante (Contributor):

Perhaps we can deprecate this earlier (like in v4.37)?

I haven't seen users fiddling with this internal variable :)

patrickvonplaten (Contributor, Author):

Ok for me!



@patrickvonplaten patrickvonplaten changed the title [Assistant Generation] Improve enc dec [Assistant Generation] Improve Encoder Decoder Oct 11, 2023
patrickvonplaten (Contributor, Author) commented:

The failing Hub test seems to be flaky.

This PR is ready for a final review.

inputs_embeds = self.embed_tokens(input) * self.embed_scale

if attention_mask is None:
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
patrickvonplaten (Contributor, Author):

This was a bug previously: the attention_mask length should equal the inputs_embeds length plus the past_key_values length.
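Stated as arithmetic: with a key/value cache, the mask must cover the cached positions as well as the new embeddings. A minimal sketch of the corrected length:

```python
def expected_attention_mask_length(inputs_embeds_len, past_key_values_length):
    # The mask must cover the cached positions plus the new embeddings,
    # not just the new embeddings.
    return past_key_values_length + inputs_embeds_len

# Prefill: no cache yet, mask length == sequence length.
assert expected_attention_mask_length(10, 0) == 10
# One decoding step on top of 10 cached tokens -> mask of length 11.
assert expected_attention_mask_length(1, 10) == 11
```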

ArthurZucker (Collaborator) left a comment:

Very clean, thanks for improving the performance!

Comment on lines +240 to +241
- `"heuristic"`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2, else reduce by 1
ArthurZucker (Collaborator):

Wondering if the schedule parameters should be hard-coded, but this is fine for me.

@patrickvonplaten patrickvonplaten merged commit da69de1 into main Oct 11, 2023
@patrickvonplaten patrickvonplaten deleted the improve_assistant_generation_enc_dec branch October 11, 2023 13:52