Incorrect scores returned in Whisper with num_beams>1 #32246

@cifkao

TL;DR: With num_beams>1, generate() returns scores that belong to the wrong sequence in the batch/beam.

System Info

  • transformers version: 4.43.2
  • Platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.31
  • Python version: 3.9.18
  • Huggingface_hub version: 0.24.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA RTX A6000

Who can help?

@sanchit-gandhi @ylacombe @patrickvonplaten

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
import numpy as np

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", torch_dtype=torch.float16
)
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model.cuda()

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
audio = ds[0]["audio"]["array"].astype(np.float32)
inputs = processor(
    [audio],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    sampling_rate=16_000,
)
inputs = inputs.to(model.device, torch.float16)

generation_output = model.generate(
    **inputs,
    language="en",
    return_timestamps=True,
    return_segments=True,
    output_scores=True,
    num_beams=2,
    temperature=0.0,
    logprob_threshold=0.0,
    compression_ratio_threshold=2.4,
    no_speech_threshold=0.6,
)

# Print each token along with its log-probability and beam index
segment = generation_output["segments"][0][0]
tokens = segment["result"]["sequences"]
scores = segment["result"]["scores"]
beam_indices = segment["result"]["beam_indices"]
logprobs = torch.as_tensor([s.log_softmax(-1)[t] for s, t in zip(scores, segment["tokens"])])
print(*[(processor.tokenizer.decode([t], decode_with_timestamps=True), s.item(), b.item()) for s, t, b in zip(logprobs, tokens, beam_indices)], sep="\n")

Output:

('<|0.00|>', -0.061553955078125, 0)
(' Folks', -1.9033203125, 0)
(',', -0.406005859375, 0)
(' if', -0.038330078125, 0)
(' you', -0.0019512176513671875, 0)
(' watch', -0.1451416015625, 0)
(' the', -0.1986083984375, 0)
(' show', -0.0019512176513671875, 0)
(',', -0.28076171875, 0)
(' you', -0.291015625, 0)
(' know', -0.026031494140625, 0)
(' I', -1.05859375, 0)
(' spent', -10.125, 1)
(' a', -12.5625, 1)
(' lot', -8.1796875, 1)
(' of', -11.8359375, 1)
(' time', -5.4296875, 1)
(' right', -14.109375, 1)
(' over', -inf, 1)
(' there', -0.0307769775390625, 0)
('.', -0.236328125, 0)
('<|4.00|>', -3.7109375, 0)

# Now run a forward pass with the generated tokens
inputs_forward = {k: v[..., :3000].cuda() for k, v in inputs.items()}
inputs_forward["decoder_input_ids"] = torch.cat(
    [
        torch.as_tensor(processor.tokenizer.encode("<|startoftranscript|><|en|><|transcribe|>", add_special_tokens=False)),
        tokens,
    ],
)[None].cuda()

with torch.inference_mode():
    output_forward = model(**inputs_forward)

# Print each token along with its log-probability
print(*[(processor.tokenizer.decode([t], decode_with_timestamps=True), s[t].item()) for s, t in zip(output_forward.logits.squeeze(0).log_softmax(-1), inputs_forward["decoder_input_ids"].squeeze(0)[1:])], sep="\n")

Output:

('<|en|>', -0.3857421875)
('<|transcribe|>', -6.556510925292969e-06)
('<|0.00|>', -0.1917724609375)
(' Folks', -1.939453125)
(',', -0.40966796875)
(' if', -0.038818359375)
(' you', -0.0020503997802734375)
(' watch', -0.1458740234375)
(' the', -0.204345703125)
(' show', -0.00235748291015625)
(',', -0.276123046875)
(' you', -0.299560546875)
(' know', -0.0259552001953125)
(' I', -1.06640625)
(' spent', -0.5)
(' a', -0.0234832763671875)
(' lot', -0.023773193359375)
(' of', -0.0233001708984375)
(' time', -0.01520538330078125)
(' right', -1.12109375)
(' over', -0.0208892822265625)
(' there', -0.0306854248046875)
('.', -0.2406005859375)
('<|4.00|>', -3.798828125)
...

Comparing the two outputs, the scores returned by generate() are close to those of the forward pass (though not identical) when the beam index is 0, but much lower, and even -inf, when the beam index is 1. This suggests that we are getting scores from the wrong sequence in the beam. (A small difference in the scores near timestamp tokens might be explained by the logits processor, but the score of a token that was actually generated should never be -inf.)
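
To make the comparison concrete, here is a small dependency-free sketch that flags the positions where the two sets of log-probabilities disagree. The helper name `find_mismatches` and the tolerance of 0.2 are my own illustrative choices, not part of transformers; the numbers are copied from the two outputs above for the tokens ' know', ' I', ' spent', ' a', ' over' and ' there'.

```python
import math


def find_mismatches(generate_logprobs, forward_logprobs, tol=0.2):
    """Return positions where the log-probs differ by more than `tol`
    or where either value is not finite (e.g. -inf).

    Hypothetical helper for illustration; `tol` is an arbitrary threshold.
    """
    bad = []
    for i, (g, f) in enumerate(zip(generate_logprobs, forward_logprobs)):
        if not (math.isfinite(g) and math.isfinite(f)) or abs(g - f) > tol:
            bad.append(i)
    return bad


# Log-probs reported by generate() and the beam index of each token
gen = [-0.026031494140625, -1.05859375, -10.125, -12.5625,
       float("-inf"), -0.0307769775390625]
beams = [0, 0, 1, 1, 1, 0]

# Log-probs of the same tokens from the forward pass
fwd = [-0.0259552001953125, -1.06640625, -0.5, -0.0234832763671875,
       -0.0208892822265625, -0.0306854248046875]

mismatched = find_mismatches(gen, fwd)
print(mismatched)                      # [2, 3, 4]
print([beams[i] for i in mismatched])  # [1, 1, 1]
```

In this excerpt every flagged position has beam index 1, and every position with beam index 0 passes the check.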

The bug seems to be in _postprocess_outputs. The code there works fine with num_beams==1, but with num_beams>1 each item in seek_outputs["scores"] has shape [num_beams * batch_size, vocab_size], while the code expects [batch_size, vocab_size]. As a result, instead of selecting the correct sequence in the beam/batch, it combines scores from different sequences.
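
The effect of this shape mismatch can be illustrated with a toy example. This is only a sketch with made-up numbers and plain nested lists in place of tensors, not the actual transformers code. It assumes that rows are laid out as batch_idx * num_beams + beam and that beam_indices point into that flattened dimension; the function names are hypothetical.

```python
# Toy data (made up): 2 batch items, 2 beams, vocabulary of 3 tokens, 2 steps.
batch_size, num_beams, vocab_size = 2, 2, 3

# step_scores[step][row][token], with batch_size * num_beams rows per step.
# Assumed layout: row = batch_idx * num_beams + beam.
step_scores = [
    [[-0.1, -2.0, -3.0], [-5.0, -6.0, -7.0],
     [-0.2, -1.0, -4.0], [-8.0, -9.0, -0.3]],
    [[-4.0, -0.5, -3.0], [-0.4, -6.0, -7.0],
     [-2.0, -0.6, -4.0], [-8.0, -0.7, -9.0]],
]

# beam_indices[batch_idx][step]: the row the final hypothesis came from.
beam_indices = [[0, 1], [3, 2]]
# tokens[batch_idx][step]: the token chosen at each step.
tokens = [[0, 0], [2, 1]]


def token_scores_naive(batch_idx):
    """Index rows by batch index, as if there were only batch_size rows.
    This is only correct when num_beams == 1."""
    return [step_scores[s][batch_idx][tokens[batch_idx][s]]
            for s in range(len(step_scores))]


def token_scores_by_beam(batch_idx):
    """Index rows by the beam index recorded for each step."""
    return [step_scores[s][beam_indices[batch_idx][s]][tokens[batch_idx][s]]
            for s in range(len(step_scores))]


print(token_scores_naive(0))    # [-0.1, -4.0]
print(token_scores_by_beam(0))  # [-0.1, -0.4]
print(token_scores_naive(1))    # [-7.0, -6.0]
print(token_scores_by_beam(1))  # [-0.3, -0.6]
```

For the first batch item the two lookups agree at the step where the beam index is 0 and diverge where it is 1, which is the same pattern as in the output of generate() above.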

Expected behavior

The scores returned from generate() should match those of the forward pass.
