Description
When indices is passed with dtype long (int64) to top_k_top_p_sampling_from_logits(), the function silently returns wrong samples. Ideally the function would either reject int64 indices or cast them to int32 automatically.
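Until the library handles this itself, a caller-side guard is one possible workaround. The sketch below is not part of flashinfer: `to_int32_indices` and `sample_with_safe_indices` are hypothetical helper names, and the sampling function is passed in as `sample_fn` so the guard makes no assumptions about the library beyond the call shape used in this report.

```python
import torch


def to_int32_indices(indices: torch.Tensor, num_rows: int) -> torch.Tensor:
    """Hypothetical helper: validate row indices and return them as int32."""
    if indices.dtype not in (torch.int32, torch.int64):
        raise TypeError(f"indices must be int32 or int64, got {indices.dtype}")
    if indices.numel() > 0:
        # Note: on a CUDA tensor these reductions force a host sync.
        lo, hi = int(indices.min()), int(indices.max())
        if lo < 0 or hi >= num_rows:
            raise ValueError(f"indices must lie in [0, {num_rows}), got [{lo}, {hi}]")
    return indices.to(torch.int32)


def sample_with_safe_indices(sample_fn, logits, top_k, top_p, indices):
    """Hypothetical wrapper: cast indices to int32 before calling sample_fn."""
    safe = to_int32_indices(indices, logits.shape[0])
    return sample_fn(logits, top_k, top_p, indices=safe)
```

Assuming the call shape from this report, usage would look like `sample_with_safe_indices(flashinfer.sampling.top_k_top_p_sampling_from_logits, logits, top_k, top_p, indices.long())`.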
Reproduce:
In the example below we pass indices [0, 0, 1, 1], meaning we want two samples from each logits row. The logits rows are designed so that row 0 always yields sample 1 and row 1 always yields sample 0, so we expect the samples to be [1, 1, 0, 0]. With int64 indices the function instead returns [1, 1, 1, 1].
import torch
import flashinfer

torch.manual_seed(42)
top_p = 0.5
top_k = 3
# Row 0 always samples token 1; row 1 always samples token 0.
logits = torch.tensor([[-1, 100], [100, -1]], device="cuda")
# Two samples per row: rows 0, 0, 1, 1.
indices = torch.tensor([0, 0, 1, 1], device="cuda", dtype=torch.int32)

# int32 indices: expected result [1, 1, 0, 0].
correct_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, top_k, top_p, indices=indices
)
print(correct_samples)

# int64 indices: returns [1, 1, 1, 1] instead.
incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, top_k, top_p, indices=indices.long()
)
print(incorrect_samples)
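One possible explanation, not confirmed here, is that the kernel reads the int64 buffer as if it held int32 values. The standard-library sketch below shows what that reinterpretation would yield on a little-endian machine. The row-to-sample mapping is taken from the example above; the helper name `reinterpret_int64_as_int32` is made up for illustration.

```python
import struct


def reinterpret_int64_as_int32(values):
    """Pack values as little-endian int64, then read the bytes back as int32."""
    raw = struct.pack(f"<{len(values)}q", *values)
    return list(struct.unpack(f"<{2 * len(values)}i", raw))


indices = [0, 0, 1, 1]
words = reinterpret_int64_as_int32(indices)

# A kernel expecting four int32 indices would only read the first four words.
seen_by_kernel = words[: len(indices)]

# From the example: row 0 always samples token 1, row 1 always samples token 0.
row_to_sample = {0: 1, 1: 0}
predicted = [row_to_sample[i] for i in seen_by_kernel]

print(words)           # [0, 0, 0, 0, 1, 0, 1, 0]
print(seen_by_kernel)  # [0, 0, 0, 0]
print(predicted)       # [1, 1, 1, 1]
```

Under this assumption every sample would be drawn from row 0, which matches the [1, 1, 1, 1] output seen in the report. Confirming it would require inspecting how the library passes the indices pointer to its kernel.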