add attention_mask and position_ids in assisted model #26892
gante merged 14 commits into huggingface:main from
Conversation
Hi @jiqing-feng 👋 I agree in principle with the changes that you are proposing, but you probably need to make a few changes to get our CI green :)

Hi @gante . I use

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
gante left a comment:
Added a few nits -- after those are addressed, we're ready to merge :)
Hi @gante . Would you please review it again? Thx!
gante left a comment:
Thank you for iterating! 💛
src/transformers/generation/utils.py (outdated)

    else:
        input_ids_len = assistant_inputs["input_ids"].shape[-1]

    if input_ids_len not in (0, 1):

Suggested change:
-    if input_ids_len not in (0, 1):
+    if input_ids_len not in (1, 2):
@jiqing-feng Ah, actually I have two requests before asking for the green light of a core maintainer:
Hi @gante . I tested it on my CPU device since the GPU is unavailable to me. The new branch is a little faster (around 3%) than the main branch. The test script is as follows; feel free to test it on both GPU and CPU.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

prompt = "Speculative decoding is"
checkpoint = "bigscience/bloom-7b1"
assistant_checkpoint = "bigscience/bloom-560m"
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
generation_kwargs = {"do_sample": False, "max_new_tokens": 64, "temperature": 1.0, "top_p": 1.0, "num_beams": 1}
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)

for i in range(5):
    start = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **generation_kwargs)
    end = time.time()
    new_tokens = outputs.shape[-1] - inputs["input_ids"].shape[-1]
    print(f"Assistant decoding latency per token is {(end - start) / new_tokens * 1000} ms")
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
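The per-token latency computation in the loop above can be factored into a small helper. This is a hypothetical convenience function, not part of the PR or of transformers; the name `latency_per_token_ms` is made up for illustration:

```python
import time

def latency_per_token_ms(generate_fn, prompt_len):
    """Time a zero-argument generation callable and return ms per new token.

    generate_fn must return the full generated sequence (prompt included),
    so the new-token count is its length minus prompt_len.
    """
    start = time.time()
    output_len = len(generate_fn())
    elapsed = time.time() - start
    new_tokens = output_len - prompt_len
    return elapsed / new_tokens * 1000.0
```

For the script above one would pass something like `lambda: model.generate(...)[0]` (the first row of the batch, so `len()` gives the sequence length) together with `inputs["input_ids"].shape[-1]`.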
Hi @gante . Could you have a look at this? Thx!
Hi @jiqing-feng Running on my end ( i.e. the newly generated masks that are appended must be created on the same device as the existing mask :)
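The device issue described above can be sketched as follows. This is an assumption about the shape of the fix, not the PR's exact code: when the attention mask is extended by one position per generated token, the new column should be allocated on the existing mask's device instead of a hard-coded one:

```python
import torch

# Stand-in for the existing attention mask (could live on CPU or GPU).
attention_mask = torch.ones(1, 5)

# new_ones() inherits both device and dtype from the existing mask,
# so CPU and GPU runs both work without an explicit device argument.
new_column = attention_mask.new_ones(attention_mask.shape[0], 1)
attention_mask = torch.cat([attention_mask, new_column], dim=-1)
```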
Hi @gante . Would you please try it again? It should be fixed, and I also tested it on an A100; the results and performance are exactly the same. BTW, the failing test seems unrelated to my changes.
@jiqing-feng perfect, all works well on my end. Two related notes:

👉 you will need to rebase your changes to fix both issues, but only after the PR linked above gets merged. You may get minor rebase issues due to 2., but they should be trivial to fix. After that is done, I'll tag a core maintainer for a final quick check :)
Hi @gante . I removed
🤦 my apologies, you're absolutely right. In that case, rebasing to get the CI green is all you need to do. Tagging a core maintainer for a quick final check :) |
gante left a comment:
Good to go, thank you for iterating with me 💛
(Note: results also validated on my end, no slowdown nor generative performance drop)
Hi @gante . I see the PR you mentioned has been merged and my PR is already up to date, but some of the CI jobs are still red.
@jiqing-feng There were some unexpected failures because of new package releases - thankfully not related to this PR! They should now be resolved on main - rebasing should fix them here.
Yes, I meant to add a test to the CI runs. It looks like it should be tested in tests/generation/test_utils.py - but I'll let @gante confirm

(woops, wrong button)
@amyeroberts not sure if we can test this feature reliably: there are no output differences, since assisted generation always outputs what the main model dictates, and this PR only modifies the assistant model's inputs to be more aligned with the main model's. What we should see on average is a higher speedup with masked inputs, as the assistant model will receive the same inputs and thus has a higher chance of matching the main model, but that is far from guaranteed for all calls. A speed test would be very flaky 🤔
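The intuition above can be made concrete with a toy model (my own sketch, not transformers code): if each drafted token independently matches the main model with probability p, and the assistant drafts k tokens per call, the expected number of tokens committed per main-model forward pass is the geometric sum 1 + p + p² + … + pᵏ. Better-aligned assistant inputs raise p, which raises the expected speedup, but any single call can still accept zero draft tokens, which is why a speed test is flaky:

```python
def expected_tokens_per_call(p: float, k: int) -> float:
    """Expected tokens committed per main-model forward pass.

    p: probability that a drafted token matches the main model's choice
    k: number of tokens the assistant drafts per call
    Returns the geometric sum 1 + p + p**2 + ... + p**k.
    """
    return sum(p ** i for i in range(k + 1))

# With no matches we fall back to 1 token per call; with perfect
# matches we commit all k drafts plus the bonus token.
print(expected_tokens_per_call(0.0, 5))  # → 1.0
print(expected_tokens_per_call(1.0, 5))  # → 6.0
```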
@gante I understand - I wasn't clear enough before. Really all I was looking for is to make sure that this can be safely used with different assistant models, i.e. can I pass in a decoder-only model? How about an encoder-decoder one? So not speed or values, just the API.
@amyeroberts we do have Mixin tests (e.g.), so any issue regarding the API should have been caught there :)
@gante Sweet - in that case it's all good 👍 Re the failing tests - there are some PRs due to be merged which should (hopefully, this time) resolve the issues we've been having
amyeroberts left a comment:
Thanks again for adding!
Hi, @gante @amyeroberts . All CI checks are green. I think it is time to merge : )
@jiqing-feng thank you for iterating with us and making
@amyeroberts @jiqing-feng There are currently some unexpected CI failures caused by
Hi, @VsonicV . Sorry for the failed CI. It is weird that I can successfully run pytest in my local repo (which is updated to origin/main). I see that your CI failed at
Hi, @jiqing-feng, thanks for the quick check. Exactly the same thing happened to me: I can run
I submitted a new PR, and all CI passed. Would you apply my PR and see if the CI is ok? Furthermore, it is worth trying to update your repo by merging origin/main and pushing the updates to rerun the CI.
@jiqing-feng Hi, thanks for this prompt fix! I will rebase my PR and re-do the CI checks after your new PR is merged. Fingers crossed!
This PR broke speculative decoding for Whisper, can we maybe revert it for now?
This reverts commit 184f60d.

Issue reported here: https://huggingface.co/openai/whisper-large-v3/discussions/20

…ngface#27523)
* Revert "add attention_mask and position_ids in assisted model (huggingface#26892)"
  This reverts commit 184f60d.
* more debug
Hi @gante
Do you think that we should also add `assistant_attention_mask` and `assistant_position_ids` in `assisted_decoding`? I see that the original model has `attention_mask` and `position_ids` (in most models) in the model inputs, but the assistant model has no such inputs. If you think it is okay to align the inputs of the original model and the assistant model, maybe we can find a more elegant way to integrate it. Thx!
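The alignment the opening question describes can be sketched with the standard pattern for deriving position ids from a padded attention mask (this is the common transformers idiom, shown here as an assumption about what "aligning the inputs" means, not the PR's exact diff):

```python
import torch

# Left-padded batch: positions 0-1 are padding, 2-4 are real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])

# Cumulative sum over real tokens gives 0-based positions; padding
# slots get a harmless placeholder value (1) since they are masked out.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # → tensor([[1, 1, 0, 1, 2]])
```

Feeding the assistant model the same mask-derived positions as the main model is what gives it a higher chance of drafting the tokens the main model would pick.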