1616#ifndef FLASHINFER_SAMPLING_CUH_
1717#define FLASHINFER_SAMPLING_CUH_
1818
19- #include < driver_types.h>
20-
2119#include < cub/block/block_adjacent_difference.cuh>
2220#include < cub/block/block_reduce.cuh>
2321#include < cub/block/block_scan.cuh>
@@ -347,13 +345,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
347345 }
348346 __syncthreads ();
349347 if (tx == 0 ) {
348+ output[bx] = sampled_id;
350349 if (temp_storage.data .block_aggregate .pair .count >= k) {
351350 // failed to sample within MAX_TOP_P_ROUNDS
352351 if (success != nullptr ) {
353352 success[bx] = false ;
354353 }
355354 } else {
356- output[bx] = sampled_id;
357355 if (success != nullptr ) {
358356 success[bx] = true ;
359357 }
@@ -433,13 +431,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
433431 }
434432 __syncthreads ();
435433 if (tx == 0 ) {
434+ output[bx] = sampled_id;
436435 if (float (q) >= top_p) {
437436 // failed to sample within MAX_TOP_P_ROUNDS
438437 if (success != nullptr ) {
439438 success[bx] = false ;
440439 }
441440 } else {
442- output[bx] = sampled_id;
443441 if (success != nullptr ) {
444442 success[bx] = true ;
445443 }
@@ -539,13 +537,13 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
539537 }
540538 __syncthreads ();
541539 if (tx == 0 ) {
540+ output[bx] = sampled_id;
542541 if (pivot < scaled_p) {
543542 // failed to sample within MAX_ROUNDS
544543 if (success != nullptr ) {
545544 success[bx] = false ;
546545 }
547546 } else {
548- output[bx] = sampled_id;
549547 if (success != nullptr ) {
550548 success[bx] = true ;
551549 }
@@ -627,13 +625,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
627625 }
628626 __syncthreads ();
629627 if (tx == 0 ) {
628+ output[bx] = sampled_id;
630629 if (temp_storage.data .block_aggregate .pair .count >= k || float (q) >= p) {
631630 // failed to sample within MAX_TOP_P_ROUNDS
632631 if (success != nullptr ) {
633632 success[bx] = false ;
634633 }
635634 } else {
636- output[bx] = sampled_id;
637635 if (success != nullptr ) {
638636 success[bx] = true ;
639637 }
@@ -808,7 +806,7 @@ struct RenormTempStorage {
808806template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
809807 typename DType>
810808__global__ void TopPRenormProbKernel (DType* probs, DType* renormed_prob, DType* top_p_arr,
811- float top_p_val, float eps, uint32_t d) {
809+ float top_p_val, uint32_t d) {
812810 const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
813811 const uint32_t row_idx = bx;
814812 float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
@@ -844,12 +842,20 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
844842 threadlocal_max_val = temp_storage.data .max_val ;
845843
846844 float low = 0 , high = threadlocal_max_val;
845+ DType min_gt_low, max_le_high;
847846 DType sum_low (1 );
848- // f(x) = probs[probs > x], f(x) is non-increasing
849- // loop invariant: f(low) >= p, f(high) < p
850- while (high - low > eps) {
847+ // f(x) = sum(probs[probs > x]), f(x) is non-increasing
848+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
849+ // loop invariant:
850+ // - f(low) >= p, f(high) < p
851+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
852+ // stopping condition
853+ // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
854+ do {
851855 DType threadlocal_sum (0 );
852856 float mid = (low + high) / 2 ;
857+ min_gt_low = high;
858+ max_le_high = low;
853859 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
854860 probs_vec.fill (DType (0 ));
855861 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -858,26 +864,42 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
858864#pragma unroll
859865 for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
860866 probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : DType (0 );
867+ if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
868+ min_gt_low = min (min_gt_low, probs_vec[j]);
869+ }
870+ if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
871+ max_le_high = max (max_le_high, probs_vec[j]);
872+ }
861873 }
862874 threadlocal_sum +=
863875 BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
864876 .Sum <VEC_SIZE>(probs_greater_than_pivot);
865877 __syncthreads ();
866878 }
879+ min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
880+ .Reduce (min_gt_low, cub::Min ());
881+ __syncthreads ();
882+ max_le_high =
883+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
884+ .Reduce (max_le_high, cub::Max ());
867885 if (tx == 0 ) {
868886 temp_storage.data .block_aggregate .value = threadlocal_sum;
887+ temp_storage.data .min_val = min_gt_low;
888+ temp_storage.data .max_val = max_le_high;
869889 }
870890 __syncthreads ();
871891 threadlocal_sum = temp_storage.data .block_aggregate .value ;
892+ min_gt_low = temp_storage.data .min_val ;
893+ max_le_high = temp_storage.data .max_val ;
872894 if (threadlocal_sum >= p) {
873895 low = mid;
874896 sum_low = float (threadlocal_sum);
875897 } else {
876- high = mid;
898+ high = min ( mid, max_le_high) ;
877899 }
878- }
900+ } while (min_gt_low != max_le_high);
879901
880- DType normalizer = math::ptx_rcp (max (sum_low, eps ));
902+ DType normalizer = math::ptx_rcp (max (sum_low, 1e-8 ));
881903
882904 // normalize
883905 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -898,7 +920,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
898920template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
899921 typename DType, typename IdType>
900922__global__ void TopKMaskLogitsKernel (DType* logits, DType* masked_logits, IdType* top_k_arr,
901- uint32_t top_k_val, float eps, uint32_t d) {
923+ uint32_t top_k_val, uint32_t d) {
902924 const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
903925 const uint32_t row_idx = bx;
904926 uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -941,12 +963,20 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
941963 threadlocal_min_val = temp_storage.data .min_val ;
942964
943965 float low = threadlocal_min_val - 1 , high = threadlocal_max_val;
966+ DType min_gt_low, max_le_high;
944967 // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
945- // loop invariant: f(low) >= k, f(high) < k
946- while (high - low > eps) {
968+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
969+ // loop invariant:
970+ // - f(low) >= k, f(high) < k
971+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
972+ // stopping condition: min_gt_low == max_le_high
973+ // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
974+ do {
947975 int threadlocal_count_sum = 0 ;
948976 int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
949977 float mid = (low + high) / 2 ;
978+ min_gt_low = high;
979+ max_le_high = low;
950980 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
951981 logits_vec.fill (DType (0 ));
952982 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -956,23 +986,41 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
956986 for (uint32_t j = 0 ; j < VEC_SIZE; ++j) {
957987 probs_greater_than_pivot_count[j] =
958988 logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
989+ if (logits_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
990+ min_gt_low = min (min_gt_low, logits_vec[j]);
991+ }
992+ if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
993+ max_le_high = max (max_le_high, logits_vec[j]);
994+ }
959995 }
960996 threadlocal_count_sum +=
961997 BlockReduce<int , BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce_int )
962998 .Sum <VEC_SIZE>(probs_greater_than_pivot_count);
963999 __syncthreads ();
9641000 }
1001+ min_gt_low =
1002+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1003+ .Reduce (min_gt_low, cub::Min ());
1004+ __syncthreads ();
1005+ max_le_high =
1006+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1007+ .Reduce (max_le_high, cub::Max ());
1008+ __syncthreads ();
9651009 if (tx == 0 ) {
9661010 temp_storage.data .block_aggregate .count = threadlocal_count_sum;
1011+ temp_storage.data .min_val = min_gt_low;
1012+ temp_storage.data .max_val = max_le_high;
9671013 }
9681014 __syncthreads ();
9691015 threadlocal_count_sum = temp_storage.data .block_aggregate .count ;
1016+ min_gt_low = temp_storage.data .min_val ;
1017+ max_le_high = temp_storage.data .max_val ;
9701018 if (threadlocal_count_sum >= k) {
9711019 low = mid;
9721020 } else {
973- high = mid;
1021+ high = min ( mid, max_le_high) ;
9741022 }
975- }
1023+ } while (min_gt_low != max_le_high);
9761024 pivot = low;
9771025 }
9781026
@@ -996,7 +1044,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
9961044template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
9971045 typename DType, typename IdType>
9981046__global__ void TopKRenormProbKernel (DType* probs, DType* renormed_prob, IdType* top_k_arr,
999- uint32_t top_k_val, float eps, uint32_t d) {
1047+ uint32_t top_k_val, uint32_t d) {
10001048 const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
10011049 const uint32_t row_idx = bx;
10021050 uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -1033,13 +1081,21 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
10331081 threadlocal_max_val = temp_storage.data .max_val ;
10341082
10351083 float low = 0 , high = threadlocal_max_val;
1084+ DType min_gt_low, max_le_high;
10361085 DType sum_low (1 );
10371086 // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
1038- // loop invariant: f(low) >= k, f(high) < k
1039- while (high - low > eps) {
1087+ // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
1088+ // loop invariant:
1089+ // - f(low) >= k, f(high) < k
1090+ // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
1091+ // stopping condition: min_gt_low == max_le_high
1092+ // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
1093+ do {
10401094 Pair<DType> threadlocal_sum{DType (0 ), 0 };
10411095 Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
10421096 float mid = (low + high) / 2 ;
1097+ min_gt_low = high;
1098+ max_le_high = low;
10431099 for (uint32_t i = 0 ; i < ceil_div (d, BLOCK_THREADS * VEC_SIZE); ++i) {
10441100 probs_vec.fill (DType (0 ));
10451101 if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
@@ -1050,26 +1106,44 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
10501106 probs_greater_than_pivot_pair[j] = {
10511107 (probs_vec[j] > mid) ? probs_vec[j] : DType (0 ),
10521108 (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
1109+ if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1110+ min_gt_low = min (min_gt_low, probs_vec[j]);
1111+ }
1112+ if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
1113+ max_le_high = max (max_le_high, probs_vec[j]);
1114+ }
10531115 }
10541116 threadlocal_sum += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
10551117 temp_storage.block_prim .reduce_pair )
10561118 .Sum <VEC_SIZE>(probs_greater_than_pivot_pair);
10571119 __syncthreads ();
10581120 }
1121+ min_gt_low =
1122+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1123+ .Reduce (min_gt_low, cub::Min ());
1124+ __syncthreads ();
1125+ max_le_high =
1126+ BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim .reduce )
1127+ .Reduce (max_le_high, cub::Max ());
1128+ __syncthreads ();
10591129 if (tx == 0 ) {
10601130 temp_storage.data .block_aggregate .pair = threadlocal_sum;
1131+ temp_storage.data .min_val = min_gt_low;
1132+ temp_storage.data .max_val = max_le_high;
10611133 }
10621134 __syncthreads ();
10631135 threadlocal_sum = temp_storage.data .block_aggregate .pair ;
1136+ min_gt_low = temp_storage.data .min_val ;
1137+ max_le_high = temp_storage.data .max_val ;
10641138 if (threadlocal_sum.count >= k) {
10651139 low = mid;
10661140 sum_low = float (threadlocal_sum.value );
10671141 } else {
1068- high = mid;
1142+ high = min ( mid, max_le_high) ;
10691143 }
1070- }
1144+ } while (min_gt_low != max_le_high);
10711145
1072- normalizer = math::ptx_rcp (max (sum_low, eps ));
1146+ normalizer = math::ptx_rcp (max (sum_low, 1e-8 ));
10731147 pivot = low;
10741148 }
10751149
@@ -1090,7 +1164,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
10901164}
10911165
10921166template <typename DType>
1093- cudaError_t TopPRenormProb (DType* probs, DType* renormed_prob, DType* top_p_arr, float eps,
1167+ cudaError_t TopPRenormProb (DType* probs, DType* renormed_prob, DType* top_p_arr,
10941168 uint32_t batch_size, float top_p_val, uint32_t d,
10951169 cudaStream_t stream = 0 ) {
10961170 const uint32_t BLOCK_THREADS = 1024 ;
@@ -1099,7 +1173,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
10991173 const uint32_t smem_size = sizeof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
11001174 dim3 nblks (batch_size);
11011175 dim3 nthrs (BLOCK_THREADS);
1102- void * args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &eps, & d};
1176+ void * args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
11031177 DISPATCH_ALIGNED_VEC_SIZE (vec_size, VEC_SIZE, {
11041178 auto kernel = TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
11051179 FLASHINFER_CUDA_CALL (
@@ -1110,7 +1184,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
11101184}
11111185
11121186template <typename DType, typename IdType>
1113- cudaError_t TopKRenormProb (DType* probs, DType* renormed_prob, IdType* top_k_arr, float eps,
1187+ cudaError_t TopKRenormProb (DType* probs, DType* renormed_prob, IdType* top_k_arr,
11141188 uint32_t batch_size, uint32_t top_k_val, uint32_t d,
11151189 cudaStream_t stream = 0 ) {
11161190 const uint32_t BLOCK_THREADS = 1024 ;
@@ -1119,7 +1193,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
11191193 const uint32_t smem_size = sizeof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
11201194 dim3 nblks (batch_size);
11211195 dim3 nthrs (BLOCK_THREADS);
1122- void * args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &eps, & d};
1196+ void * args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
11231197 DISPATCH_ALIGNED_VEC_SIZE (vec_size, VEC_SIZE, {
11241198 auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
11251199 FLASHINFER_CUDA_CALL (
@@ -1130,7 +1204,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
11301204}
11311205
11321206template <typename DType, typename IdType>
1133- cudaError_t TopKMaskLogits (DType* logits, DType* masked_logits, IdType* top_k_arr, float eps,
1207+ cudaError_t TopKMaskLogits (DType* logits, DType* masked_logits, IdType* top_k_arr,
11341208 uint32_t batch_size, uint32_t top_k_val, uint32_t d,
11351209 cudaStream_t stream = 0 ) {
11361210 const uint32_t BLOCK_THREADS = 1024 ;
@@ -1139,7 +1213,7 @@ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_ar
11391213 const uint32_t smem_size = sizeof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
11401214 dim3 nblks (batch_size);
11411215 dim3 nthrs (BLOCK_THREADS);
1142- void * args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &eps, & d};
1216+ void * args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
11431217 DISPATCH_ALIGNED_VEC_SIZE (vec_size, VEC_SIZE, {
11441218 auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
11451219 FLASHINFER_CUDA_CALL (
0 commit comments