
[WIP] Word level timestamp for long-form generation#28984

Closed
patrickvonplaten wants to merge 2 commits into main from
add_word_level_timestamp_long

Conversation

@patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Feb 12, 2024

What does this PR do?

Fixes: #28977

We haven't added word-level timestamps for long-form generation yet. It's definitely possible, but it'll require some more changes in generate. Happy to take a closer look here in the next few days.

With the PR in its current state, one can retrieve word-level timestamps, but they are not correct because `_postprocess_outputs` is not correct. Test it with:

#!/usr/bin/env python3
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa

DEVICE = "cuda"

model_id = "openai/whisper-tiny"

processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
model.to(DEVICE)

audio, _ = librosa.load("./common_voice_fr_17299386.mp3", sr=16_000)

inputs = processor(
    audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,  # don't truncate, so the whole audio is sent to the model
    return_attention_mask=True,
    padding="longest",
)

# tile the features 8x along the time axis to simulate a long-form input;
# keep the attention mask in sync with the tiled features
inputs["input_features"] = inputs.input_features.repeat(1, 1, 8)
inputs["attention_mask"] = inputs.attention_mask.repeat(1, 8)
print(inputs.input_features.shape)

inputs = inputs.to(DEVICE, dtype=torch.float16)

outputs = model.generate(**inputs, return_token_timestamps=True, return_segments=True)

# decode token ids to text
transcription = processor.batch_decode(outputs["sequences"], skip_special_tokens=False)

print(transcription[0])
per_segment_word_timestamps = [segment["result"]["token_timestamps"] for segment in outputs["segments"][0]]
all_word_timestamps = [x + y["start"] for x, y in zip(per_segment_word_timestamps, outputs["segments"][0])]

print("Word level timestamps", all_word_timestamps)

@patrickvonplaten patrickvonplaten marked this pull request as draft February 12, 2024 19:35
@patrickvonplaten patrickvonplaten changed the title Word level timestamp for long-form generation [WIP] Word level timestamp for long-form generation Feb 12, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@patrickvonplaten
Contributor Author

It's actually much harder to do this than I thought and I sadly won't have time to finish this PR, so I'll leave it in this form.

We're facing the following problems here:

  1. It's actually kind of tricky to run the `_extract_timestamps` function for the whole batch when doing long-form generation => after some thought, it's better to run this function for every batch index. This should be changed, and would then also make the tricky cross-attention re-ordering easier / redundant.
  2. We need to split the cross attentions both by input and output length. Essentially, the output length is defined by each individual segment, and the input length by the start and end timestamps that are passed. This should be done in the `_extract_timestamps` function.
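To make point 2 concrete, here is a minimal, hypothetical sketch of what slicing the cross attentions along both axes could look like. The function name, the segment dict keys, and the shapes are illustrative assumptions, not the actual transformers internals:

```python
# Hypothetical sketch: split an averaged cross-attention matrix for ONE batch
# index into per-segment blocks, cut along both the output (token) axis and the
# input (encoder frame) axis. Not the real transformers implementation.
import numpy as np

FRAMES_PER_SECOND = 50  # Whisper's encoder emits one frame per 20 ms


def slice_cross_attentions(cross_attn, segments):
    """cross_attn: array of shape (num_output_tokens, num_input_frames).
    segments: list of dicts with "start"/"end" times in seconds and
    "num_tokens", the number of tokens decoded for that segment (assumed keys).
    Returns one attention block per segment."""
    sliced = []
    token_offset = 0
    for seg in segments:
        # output axis: only the tokens that were decoded for this segment
        rows = slice(token_offset, token_offset + seg["num_tokens"])
        # input axis: only the encoder frames between the segment's
        # start and end timestamps
        cols = slice(int(seg["start"] * FRAMES_PER_SECOND),
                     int(seg["end"] * FRAMES_PER_SECOND))
        sliced.append(cross_attn[rows, cols])
        token_offset += seg["num_tokens"]
    return sliced


# usage: two segments covering 0-2s (4 tokens) and 2-6s (6 tokens)
cross_attn = np.zeros((10, 300))
segments = [{"start": 0.0, "end": 2.0, "num_tokens": 4},
            {"start": 2.0, "end": 6.0, "num_tokens": 6}]
blocks = slice_cross_attentions(cross_attn, segments)
print([b.shape for b in blocks])  # [(4, 100), (6, 200)]
```

Each block could then be fed to the timestamp-extraction logic independently, which is also why running it per batch index (point 1) simplifies things: the re-ordering of stacked cross attentions across the batch disappears.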

If anybody in the community is willing to give this PR a try, feel free to use any/all my code.

cc @sanchit-gandhi as well

@zucchini-nlp
Member

I will be taking over this issue, since I found that no one else is working on it.

@github-actions
Contributor

github-actions bot commented May 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


Development

Successfully merging this pull request may close these issues.

Whisper Sequential long-form decoding doesn't work with timestamps per token
