Silent failure in top_k_top_p_sampling_from_logits() when indices are long #2115

@tomasruizt

Description

When indices are passed with dtype long (torch.int64) to top_k_top_p_sampling_from_logits(), the function silently returns a wrong result instead of raising an error. Ideally the function would either reject long inputs or cast them to int32 automatically.

Reproduce:

In the example below we pass indices [0, 0, 1, 1], meaning we want two samples from each logits row. The logits rows are designed so that the first row always yields token 1 and the second row always yields token 0, so we expect the samples to be [1, 1, 0, 0]. With int32 indices that is what we get; with long indices the function returns [1, 1, 1, 1] instead.

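import torch
import flashinfer

torch.manual_seed(42)
top_p = 0.5
top_k = 3
# Row 0 strongly favors token 1, row 1 strongly favors token 0.
# Explicit float dtype: integer literals would otherwise yield an int64 tensor.
logits = torch.tensor(
    [[-1, 100], [100, -1]], device="cuda", dtype=torch.float32
)
# Two samples from row 0, then two samples from row 1.
indices = torch.tensor([0, 0, 1, 1], device="cuda", dtype=torch.int32)

# int32 indices: expected samples are [1, 1, 0, 0].
correct_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, top_k, top_p, indices=indices
)
print(correct_samples)

# int64 (long) indices: no error is raised, but the samples are [1, 1, 1, 1].
incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, top_k, top_p, indices=indices.long()
)
print(incorrect_samples)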
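
One possible explanation for the [1, 1, 1, 1] output, offered as a guess and not confirmed against the flashinfer source: the kernel may read the indices buffer as int32 regardless of the tensor's actual dtype. On a little-endian machine, the int64 values [0, 0, 1, 1] viewed as int32 start with [0, 0, 0, 0], so all four samples would be drawn from row 0. The sketch below reproduces that arithmetic in plain Python with no GPU; the helper name reinterpret_int64_as_int32 and the argmax_per_row table are hypothetical and exist only for this illustration.

```python
import struct


def reinterpret_int64_as_int32(values, count):
    """Pack values as little-endian int64, then read the first `count` int32 words."""
    raw = struct.pack(f"<{len(values)}q", *values)
    as_int32 = struct.unpack(f"<{len(raw) // 4}i", raw)
    return list(as_int32[:count])


indices = [0, 0, 1, 1]

# What a kernel would see if it assumed int32 but received an int64 buffer
# (assumption: this is a guess at the failure mode, not verified behavior).
seen_by_kernel = reinterpret_int64_as_int32(indices, len(indices))

# Stand-in for sampling: row 0 always yields token 1, row 1 always yields token 0.
argmax_per_row = [1, 0]
samples = [argmax_per_row[i] for i in seen_by_kernel]

print(seen_by_kernel)  # [0, 0, 0, 0]
print(samples)         # [1, 1, 1, 1]
```

If this is indeed the cause, calling indices.to(torch.int32) before the sampling call is a workaround on the caller's side until the function validates or casts the dtype itself.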
