[Whisper] Add sequential longform decoding #27492
Conversation
QQ: @patrickvonplaten - wouldn't concatenating and passing the whole audio as input result in exploding GPU VRAM usage?
Hey @patrickvonplaten, would you mind adding the performance for bigger models? The worse the model is at predicting timestamps, the worse the performance of the chunked algorithm. I remember observing very little loss for large models! (Just as an FYI!)
The audio is chunked on the fly (there is a while loop now) |
Sure I can run it for larger models as well. I'm not 100% sure though why this matters - if we see such strong gains for smaller models we should add it nevertheless. |
gante left a comment:
Awesome feature! Looking forward to the other upgrades mentioned in Section 4.5 of the paper 🔥 The "previous text conditioning" can probably benefit from the newly added ability to return and reuse past_key_values.
Added a few minor nits
```python
>>> # transcribe audio to ids
>>> generated_ids = model.generate(**inputs)

>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
```
I agree that printing the entire generated output here would be a bad idea, but we could add something like

```python
>>> len(transcription[0])
```

That way, our doctests would fail if we start having numerical problems with the longform decoding. WDYT?
```python
elif not is_shortform:
    if return_timestamps is False:
        raise ValueError(
            "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
```
A few lines here have more than 120 characters :)
This particular sentence also has 2x "which"; rewriting might improve readability.
There could be something wrong with the way I'm initialising the pipeline, but in my single-file benchmark it just truncates the output to 30 sec.
Repro: https://github.com/Vaibhavs10/scratchpad/blob/main/conditional_long_form_generation_whisper.ipynb
Note: I'm not defining a chunk size, as it isn't defined in the example snippet up top.
It works as intended with model + generate though! 🚀
More of an overall usage remark from a developer's PoV:
How do we clarify whether the transcription strategy used is chunked or conditional? Can we allow developers to choose? Supporting this via pipeline is important IMO (see the sketch below).
Edit: To clarify, one of the biggest use cases for pipeline is to throw in an audio file in whatever format and then get the transcription for it.
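A rough sketch of how the two strategies would be selected with the current pipeline API, assuming that passing chunk_length_s opts into the chunked algorithm while omitting it uses the sequential long-form decoding added in this PR (checkpoint and file name are placeholders):

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Chunked long-form transcription: the audio is split into fixed 30s windows
# up front and the overlapping transcripts are merged via timestamps.
chunked = asr("long_audio.wav", chunk_length_s=30, stride_length_s=5, return_timestamps=True)

# Sequential long-form transcription (this PR): omit chunk_length_s so the model
# decodes 30s windows one after another, seeking forward via predicted timestamps.
sequential = asr("long_audio.wav", return_timestamps=True)
```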
ArthurZucker left a comment:
Thanks for the hard work!
Would just split the 6.2 into a separate function, not falling into the trap of having a huge generate. Same for the small functions used for decoding.
| """ | ||
|
|
||
| def __init__(self, generate_config): # support for the kwargs | ||
| def __init__( |
Sorry if I am not understanding the statement 😅 No, it was used for chunked processing because the merging algorithm heavily relies on timestamps, and also produces timestamps. See this line, which always adds the processor.
Placing a breakpoint in the call of this class with this:

```python
from transformers import pipeline
from datasets import load_dataset
import numpy as np

transcriptor = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_file = ds[0]["audio"]["array"]
long_audio = np.concatenate([ds[0]["audio"]["array"]] * 40)

out = transcriptor(long_audio, return_timestamps=True, chunk_length_s=30, stride_length_s=5)
```

(this is for me "timestamp" chunked transcription)
```python
return_segments (`bool`, *optional*, defaults to `False`):
    Whether to additionally return a list of all segments. Note that this option can only be enabled
    when doing long-form transcription.
```
By segment, do you mean segments of audio? (not sure I understand) This is only valid for sequential long-form, no?
```python
time_precision (`int`, *optional*, defaults to 0.02):
    The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
    for 20 ms.
```
We have this in the pipeline:

```python
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
```

which should be replaced, I guess?
We don't have access to the feature extractor here, sadly.
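For reference, the pipeline formula and the hard-coded 0.02 default agree for stock Whisper checkpoints; a quick sanity check using the standard config values (30-second feature-extractor chunks, 1500 encoder positions):

```python
# Standard Whisper values: the feature extractor produces 30-second chunks and
# the encoder emits max_source_positions = 1500 output frames per chunk.
chunk_length = 30.0          # seconds
max_source_positions = 1500  # encoder output frames

time_precision = chunk_length / max_source_positions
print(time_precision)  # 0.02 -> each timestamp token advances by 20 ms
```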
```python
    **kwargs,
)
if generation_config.return_timestamps is True:
    last_forced_decoder_ids = (
```
Not sure I understand: if the last token is no_timestamps, we should still have the first 2-3 forced tokens (specifically one set for the task), unless it's done after?
```python
cur_bsz = prev_bsz = batch_size

# 6.2 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
```
let's separate this into another function?
(Kind of like the merging function)
Factored out one part into a function
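For readers following along, the loop being factored out implements the sequential strategy from Section 4.5 of the paper. A simplified, hypothetical sketch (not the actual transformers implementation; all names are invented for illustration):

```python
# Hypothetical sketch of sequential long-form decoding: decode one 30s window
# at a time and "seek" forward to the last predicted timestamp, rather than
# advancing by a fixed stride as the chunked algorithm does.

FRAMES_PER_SECOND = 100  # Whisper log-mel spectrograms use a 10 ms hop
WINDOW_FRAMES = 3000     # 30 s per decoding window


def frames_covered(token_ids, timestamp_begin, time_precision=0.02):
    """How far (in mel frames) the model got within the current window."""
    timestamps = [t for t in token_ids if t >= timestamp_begin]
    seconds = (timestamps[-1] - timestamp_begin) * time_precision if timestamps else 30.0
    return int(seconds * FRAMES_PER_SECOND)


def transcribe_longform(model, processor, input_features, timestamp_begin):
    seek, texts = 0, []
    max_frames = input_features.shape[-1]
    while seek < max_frames:  # the "6.2" loop discussed above
        window = input_features[:, :, seek : seek + WINDOW_FRAMES]
        generated_ids = model.generate(window, return_timestamps=True)
        texts.append(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
        advance = frames_covered(generated_ids[0].tolist(), timestamp_begin)
        seek += advance if advance > 0 else WINDOW_FRAMES  # avoid stalling
    return "".join(texts)
```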
Nice catch, there was a typo. Added a test for the pipeline now as well.
I've adapted the tests to match the new timestamp logits processor. I've double-checked that the new timestamp logits processor gives the same WER results on long-form, and applied the suggestions. The failing test is a time-out that is not related to this PR - merging!
Thanks for implementing. It seems that longform decoding doesn't work with return_token_timestamps=True for model.generate() (nor return_timestamps="word" for pipeline()) in v4.37.2, failing at line 822 of whisper/generation_whisper.py in the private method _postprocess_outputs with the error "AttributeError: 'tuple' object has no attribute 'cpu'".
Hi @antoinethl, could you open a new issue, detailing the error encountered (including the full traceback) and a minimal reproducer?
Hi, just opened 2 issues with traceback and short example. |
What does this PR do?
This PR adds the long-form transcription as originally proposed in the Whisper codebase: https://github.com/openai/whisper and in the paper: https://cdn.openai.com/papers/whisper.pdf
To better understand long-form transcription, please have a look at Section 4.5: Strategies for Reliable Long-form Transcription of the paper.
Before this PR, transformers only had "chunked" long-form transcription, which trades accuracy for speed (see table below). In this PR we add the best-performing long-form transcription algorithm to Transformers.
Usage:
One can now use long-form transcription easily with the pipeline object by simply passing long-form audio. Previously, long-form audio was truncated to just 30 seconds. This PR makes sure that long audio is not cut when passed to the pipeline:
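A minimal sketch of that usage (checkpoint and file name are illustrative, not the exact snippet from the PR):

```python
from transformers import pipeline

# Any Whisper checkpoint works; "openai/whisper-large-v2" is just an example.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2")

# Audio longer than 30 seconds is now transcribed end-to-end instead of
# being truncated to the first 30-second window.
out = asr("long_audio.mp3", return_timestamps=True)
print(out["text"])
```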
The pipeline is great for "easy-to-set-up" code but lacks customization and readability. For example, the pipeline currently does not allow running the model with batch sizes > 1 and instead runs each audio one by one, which is very suboptimal speed-wise. To use long-form transcription with batch size > 1, you can use a snippet along the following lines:
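A sketch of batched long-form generation (the dataset and checkpoint are illustrative; the key parts are truncation=False and padding="longest", which keep the full audio, plus the attention mask so generate knows where each audio ends):

```python
from datasets import load_dataset
from transformers import AutoProcessor, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# Illustrative long-form audio; any list of >30s waveforms sampled at 16 kHz works.
ds = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
raw_audios = [ds[0]["audio"]["array"], ds[0]["audio"]["array"]]

inputs = processor(
    raw_audios,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,            # do NOT cut the audio down to 30 seconds
    padding="longest",           # pad the batch to the longest audio
    return_attention_mask=True,  # tells generate where each audio ends
)

# return_timestamps=True is required for sequential long-form generation
generated_ids = model.generate(**inputs, return_timestamps=True)
transcriptions = processor.batch_decode(generated_ids, skip_special_tokens=True)
```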
Docs have been added to give examples for both short- and long-form transcription.
But I don't think that this is enough for people to notice this method. We should in my opinion create much better guides for Whisper (will be done in a follow-up PR).
Credits
IMPORTANT: Most of the code added in this PR was copied & tweaked from the original Whisper code: https://github.com/openai/whisper/blob/main/whisper/transcribe.py. Therefore, 90% of the credit for this PR goes to @jongwook as the original author of the transcribe.py code.

Why copy the code?!: We originally weren't planning on integrating the full long-form transcription algorithm into transformers, but a couple of reasons forced us to add it now.

Next steps:
When looking at all long-form generation strategies (see the table in Section 4.5 of the paper), Transformers now has support for chunked long-form decoding (previously available) and sequential long-form decoding (added in this PR).
In a follow-up PR we will add: "temperature fallback", "voice activity detection", and "previous text conditioning".
Results:
Note that for "chunked transformers" the numbers are crossed through because the original results from the Whisper paper seem to have been slightly incorrect; re-running the eval gives better results.

Here are the results (WER, lower is better) for openai/whisper-tiny.en:

| | Dataset 1 | Dataset 2 | Dataset 3 | Dataset 4 |
|---|---|---|---|---|
| chunked | ~~17.5~~ | ~~24.1~~ | ~~16.4~~ | ~~17.4~~ |
| sequential (this PR) | 17.28 | 23.69 | 14.55 | 17.46 |

Here are the results for openai/whisper-small.en:

| | Dataset 1 | Dataset 2 | Dataset 3 | Dataset 4 |
|---|---|---|---|---|
| chunked | ~~15.1~~ | ~~20.6~~ | ~~18.7~~ | ~~14.5~~ |
| sequential (this PR) | 12.61 | 16.3 | 7.06 | 14.02 |

Here are the results for openai/whisper-large-v2:

| | Dataset 1 | Dataset 2 | Dataset 3 | Dataset 4 |
|---|---|---|---|---|
| chunked | ~~11.8~~ | ~~15.1~~ | ~~6.3~~ | ~~13.6~~ |
| sequential (this PR) | 11.8 | 15.0 | 5.2 | 13.5 |
It seems like the numbers we measured in the distil-whisper paper for chunked long-form are a bit off. Re-running them gives the following:
Here are the results for openai/whisper-tiny.en => the new algo is on avg. 0.01 WER abs. points worse, which means it's identical.

Here are the results for openai/whisper-small.en => the new algo is on avg. 0.005 WER abs. points worse, which means it's identical.

Here are the results for openai/whisper-large-v2 => the new algo is on avg. 0.0225 WER abs. points better, which means it's identical (or a tiny, tiny bit better).
Batch size > 1
The code now fully functions for batch size > 1 (I made sure that results on the four datasets are within +/- 0.1% WER). When using batch size = 8, there is a 4x speed-up for large-v2, a 2x speed-up for small, and a 1.5x speed-up for tiny. The bigger the model, the larger the speed-up!
One should definitely use larger batch sizes when doing long-form timestamp prediction!