
feat: Batch-size invariant FA2 Prefill & Decode #1675

Merged
yzh119 merged 22 commits into flashinfer-ai:main from Edenzzzz:invariant_prefill
Sep 15, 2025

Conversation

@Edenzzzz (Contributor) commented Sep 11, 2025

📌 Description

As mentioned in https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/, split-kv can cause non-determinism. This PR introduces a fixed split size so that results are exactly identical when the batch size changes.
Compatibility with CUDA graphs is left for future work, since varying sequence lengths can push the number of CTAs above the number captured in the graph.
It also adds a disable_split_kv flag for CUDA Graph mode.
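For illustration, a minimal usage sketch of the new knob (the `fixed_split_size` parameter is what this PR adds to `plan`; the value 2048 and the shapes here are illustrative, not from the PR):

```python
import torch

import flashinfer

# One decode request with kv_len 4096, paged KV cache in HND layout.
device, page_size, num_heads, head_dim = "cuda:0", 16, 8, 128
num_pages = 4096 // page_size
q = torch.randn(1, num_heads, head_dim, device=device, dtype=torch.float16)
kv = torch.randn(
    num_pages, 2, num_heads, page_size, head_dim, device=device, dtype=torch.float16
)
indptr = torch.tensor([0, num_pages], device=device, dtype=torch.int32)
indices = torch.arange(num_pages, device=device, dtype=torch.int32)
last_len = torch.tensor([page_size], device=device, dtype=torch.int32)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "HND", use_tensor_cores=True
)
wrapper.plan(
    indptr, indices, last_len, num_heads, num_heads, head_dim, page_size,
    data_type=torch.float16, q_data_type=torch.float16,
    fixed_split_size=2048,  # KV chunks are always 2048 tokens -> batch-size invariant
)
o = wrapper.run(q, kv)
```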

cc @Fridge003. I will likely be busy with some papers this semester, but this should be a useful starting point.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @Edenzzzz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the FlashInfer library by introducing a mechanism to achieve deterministic and batch-invariant behavior for prefill and decode operations. It allows users to specify a fixed split size for processing KV caches, which helps in producing consistent results regardless of batch size, thereby mitigating non-determinism in LLM inference. The changes involve modifications across the C++ backend, Python bindings, and the core attention scheduler logic, along with the addition of a dedicated test to validate the new functionality.

Highlights

  • Batch-Invariant Prefill and Decode: Introduces a new fixed_split_size parameter to the prefill and decode operations, allowing for deterministic and batch-size invariant outputs by fixing the split size for KV cache processing. This addresses non-determinism issues in LLM inference.
  • CUDA Graph Compatibility: Explicitly notes that compatibility with CUDA graphs is not guaranteed when fixed_split_size is enabled, as varying sequence lengths can still lead to a different number of launched CTAs (Cooperative Thread Arrays), even with a fixed batch size.
  • New Test Case: Adds a new test file, tests/test_invariant_batch_decode.py, to verify the batch-invariant behavior of the decode operation when fixed_split_size is utilized.

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a fixed_split_size parameter to enable batch-invariant attention computation in prefill and tensor-core decode paths, addressing potential non-determinism. The changes are well-implemented across the C++ backend, Python wrappers, and TVM bindings. The addition of a new test to verify batch-invariance for the decode path is a great step. My review includes a few suggestions to improve documentation consistency and test coverage.

Comment thread: flashinfer/decode.py
Comment thread: flashinfer/prefill.py (outdated)
Comment thread: tests/test_invariant_batch_decode.py (outdated)
Comment on lines +59 to +157
import torch

import flashinfer


def test_batch_decode_tensor_cores(
    batch_size: int,
    invariant_bs: int,
    kv_len: int,
    fixed_split_size: int,
    page_size: int,
    num_kv_heads: int,
    group_size: int,
    head_dim: int,
    kv_layout: str,
    pos_encoding_mode: str,
):
    num_qo_heads = num_kv_heads * group_size
    q = torch.randn(
        batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16
    )
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    kv_data = (
        torch.randn(
            total_num_pages,
            2,
            num_kv_heads,
            page_size,
            head_dim,
            device="cuda:0",
            dtype=torch.float16,
        )
        / 10
        if kv_layout == "HND"
        else torch.randn(
            total_num_pages,
            2,
            page_size,
            num_kv_heads,
            head_dim,
            device="cuda:0",
            dtype=torch.float16,
        )
        / 10
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
    kv_last_page_len = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
    )

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0")

    wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout, use_tensor_cores=True
    )
    wrapper_tensor_cores.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode=pos_encoding_mode,
        data_type=torch.float16,
        q_data_type=torch.float16,
        fixed_split_size=fixed_split_size,
    )
    o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run(
        q, kv_data, return_lse=True
    )

    kv_indptr_invariant = kv_indptr[: invariant_bs + 1]
    kv_last_page_len_invariant = kv_last_page_len[:invariant_bs]
    wrapper_tensor_cores.plan(
        kv_indptr_invariant,
        kv_indices,
        kv_last_page_len_invariant,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode=pos_encoding_mode,
        data_type=torch.float16,
        q_data_type=torch.float16,
        fixed_split_size=fixed_split_size,
    )
    o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run(
        q[:invariant_bs], kv_data, return_lse=True
    )
    torch.testing.assert_close(
        o_tensor_cores[:invariant_bs], o_tensor_cores_invariant, rtol=1e-7, atol=1e-7
    )
    torch.testing.assert_close(
        lse_tensor_cores[:invariant_bs],
        lse_tensor_cores_invariant,
        rtol=1e-7,
        atol=1e-7,
    )
@gemini-code-assist (bot) commented (severity: medium):

This new test is great for verifying the batch-invariance of the decode path with tensor cores. Since the PR also adds fixed_split_size to the prefill path (BatchPrefillWithPagedKVCacheWrapper), it would be beneficial to add a similar test for prefill to ensure its batch-invariance as well. The current test only covers the decode case where query length is 1 for each request, while prefill handles variable query lengths.
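A minimal sketch of what such a prefill test could look like (hypothetical: the helper name, shapes, and defaults are illustrative, and it assumes BatchPrefillWithPagedKVCacheWrapper.plan accepts the `fixed_split_size` added by this PR):

```python
import torch

import flashinfer


def check_batch_prefill_invariance(
    batch_size=8, invariant_bs=2, qo_len=128, kv_len=4096,
    page_size=16, num_heads=8, head_dim=128, fixed_split_size=2048,
):
    device = "cuda:0"
    q = torch.randn(
        batch_size * qo_len, num_heads, head_dim, device=device, dtype=torch.float16
    )
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    kv_data = (
        torch.randn(
            total_num_pages, 2, num_heads, page_size, head_dim,
            device=device, dtype=torch.float16,
        )
        / 10
    )  # HND layout
    qo_indptr = (
        torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device=device, dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices = torch.arange(0, total_num_pages, device=device, dtype=torch.int32)
    kv_last_page_len = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device=device
    )

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "HND")

    def run(bs):
        wrapper.plan(
            qo_indptr[: bs + 1],
            kv_indptr[: bs + 1],
            kv_indices,
            kv_last_page_len[:bs],
            num_heads,
            num_heads,
            head_dim,
            page_size,
            causal=True,
            q_data_type=torch.float16,
            fixed_split_size=fixed_split_size,
        )
        return wrapper.run(q[: bs * qo_len], kv_data)

    o_full = run(batch_size)
    o_small = run(invariant_bs)
    # Batch-invariance: the first invariant_bs requests must match bitwise.
    assert torch.equal(o_full[: invariant_bs * qo_len], o_small)
```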

Comment thread: tests/test_invariant_batch_decode.py (outdated)
# test that without fixed split size, precision is different
# TODO: this works for the first 29 cases, but then fails with "illegal memory access"...?

# wrapper_tensor_cores.plan(
@Edenzzzz (author):

Feel free to debug this if anyone has time.
@Edenzzzz (author) commented Sep 11, 2025

cc @yzh119

@yzh119 (Collaborator) commented Sep 11, 2025

The easiest way to turn off split-kv in flashinfer is to change these two lines to std::numeric_limits<int>::max():

  1. std::max(128 / page_size, 1U));
  2. const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U);

@Edenzzzz (author):

Yes, but simply turning it off would hurt performance.

Comment thread: tests/test_invariant_batch_decode.py (outdated)
    o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run(
        q[:invariant_bs], kv_data, return_lse=True
    )
    torch.testing.assert_close(
@yzh119 (Collaborator):

Can we add utility functions to do a bitwise-identical comparison check?

@Edenzzzz (author):

Switched to torch.equal, which should do the job.
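For reference, a minimal sketch of such a utility (hypothetical helper, not code from this PR):

```python
import torch


def assert_bitwise_identical(a: torch.Tensor, b: torch.Tensor) -> None:
    # torch.equal is True only when the tensors have the same size and identical
    # element values, i.e. a bitwise-identical check for non-NaN tensors (NaN != NaN).
    assert a.dtype == b.dtype, f"dtype mismatch: {a.dtype} vs {b.dtype}"
    assert torch.equal(a, b), "tensors differ: outputs are not batch-invariant"
```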

@yzh119 (Collaborator) commented Sep 11, 2025

One minor note: the attention merge kernel PersistentVariableLengthMergeStatesKernel (quoted below) uses a parallel reduction; I'm not sure if we need to change it to a sequential reduction to guarantee full reproducibility.
/*!
 * \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of
 *   index sets at each position might vary.
 *
 * For CUDA graph support, the kernel can be built with a maximum sequence length and executed
 * using a truncated, dynamic sequence length passed through `seq_len_ptr`.
 *
 * \tparam vec_size The vector size used in the kernel.
 * \tparam bdx The blockDim.x used in the kernel.
 * \tparam bdy The blockDim.y used in the kernel.
 * \tparam num_smem_stages The number of stages of shared memory used in the kernel.
 * \tparam DTypeIn The data type of v.
 * \tparam DTypeO The data type of v_merged.
 * \param V The partial v of index sets. (nnz, h, d)
 * \param S The logsumexp value of index sets. (nnz, h)
 * \param indptr The start offsets of each position in the variable length array.
 * \param v_merged The merged v of index sets union. (n, h, d)
 * \param s_merged The merged logsumexp value of index sets union. (n, h)
 * \param max_seq_len The maximum sequence length supported by the kernel.
 * \param seq_len_ptr The current sequence length (number of positions populated in indptr).
 * \param num_heads The number of heads of v.
 * \param head_dim The dimension of each head.
 * \note s are logsumexp values with base 2.
 */
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
          typename DTypeO, typename IdType>
__global__ void PersistentVariableLengthMergeStatesKernel(
    DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged,
    float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr,
    uint32_t num_heads) {
  uint32_t tx = threadIdx.x, ty = threadIdx.y;
  uint32_t cta_id = blockIdx.x;
  uint32_t num_ctas = gridDim.x;
  const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
  uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
  constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
  constexpr uint32_t head_dim = vec_size * bdx;
  extern __shared__ uint8_t smem[];
  DTypeIn* v_smem = (DTypeIn*)smem;
  float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif
#pragma unroll 1
  for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) {
    // NOTE (Yilong): necessary to prevent hazard on smaller `num_index_sets`
    __syncthreads();
    uint32_t pos = i / num_heads;
    uint32_t head_idx = i % num_heads;
    state_t<vec_size> st;
    const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
    if (num_index_sets == 0) {
      vec_t<DTypeO, vec_size> v;
      v.fill(DTypeO(0.f));
      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
      if (s_merged != nullptr) {
        s_merged[pos * num_heads + head_idx] = -math::inf;
      }
      continue;
    }
    if (num_index_sets == 1) {
      vec_t<DTypeO, vec_size> v;
      v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
      if (s_merged != nullptr) {
        s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
      }
      continue;
    }
#pragma unroll
    for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
          v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
          V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
          (iter * bdy + ty) < num_index_sets);
      cp_async::commit_group();
    }
#pragma unroll 4
    for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
      if (iter % bdx == 0) {
        s_smem[ty * bdx + tx] =
            iter * bdy + (ty * bdx + tx) < num_index_sets
                ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
                : 0.f;
        __syncthreads();
      }
      cp_async::wait_group<num_smem_stages - 1>();
      __syncthreads();
      vec_t<float, vec_size> v;
      v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
      if (iter * bdy + ty < num_index_sets) {
        float s = s_smem[(iter % bdx) * bdy + ty];
        st.merge(v, s, 1);
      }
      __syncthreads();
      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
          v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
          V +
              ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
                  head_dim +
              tx * vec_size,
          (iter + num_smem_stages) * bdy + ty < num_index_sets);
      cp_async::commit_group();
    }
    cp_async::wait_group<0>();
    __syncthreads();
    st.normalize();
    threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
    st.normalize();
    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
    if (s_merged != nullptr) {
      s_merged[pos * num_heads + head_idx] = st.get_lse();
    }
  }
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

Current reduction order:

  1. threads 0-7: out[0] = merge(0, 16, 32, ...)
  2. threads 8-15: out[1] = merge(1, 17, 33, ...)
  3. ...
  4. threads 120-127: out[15] = merge(15, 31, 47, ...)

and here,

threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);

we reduce them into merge(out[0], out[1], ..., out[15]).

The order might change for different kv-lengths.
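To see why reduction order matters at all, here is a toy float32 sketch of the (v, lse) merge rule (illustrative only, not the kernel's exact arithmetic; lse is base 2 as in the kernel):

```python
import torch


def merge(v1, s1, v2, s2):
    # Merge two partial attention states (v, lse), with lse in base 2.
    s = torch.maximum(s1, s2)
    w1, w2 = torch.exp2(s1 - s), torch.exp2(s2 - s)
    return (v1 * w1 + v2 * w2) / (w1 + w2), s + torch.log2(w1 + w2)


torch.manual_seed(0)
vs, ss = torch.randn(4, 128), torch.randn(4)  # four partial states
# Sequential order: ((0, 1), 2), 3
v_seq, s_seq = vs[0], ss[0]
for i in range(1, 4):
    v_seq, s_seq = merge(v_seq, s_seq, vs[i], ss[i])
# Tree (parallel) order: (0, 1) merged with (2, 3)
v01, s01 = merge(vs[0], ss[0], vs[1], ss[1])
v23, s23 = merge(vs[2], ss[2], vs[3], ss[3])
v_tree, s_tree = merge(v01, s01, v23, s23)
# Mathematically equal, but floating-point rounding depends on the order:
print(torch.equal(v_seq, v_tree))  # typically False in float32
```

The point under discussion: the tree order is fixed for a given kv-length, so each request's result is still deterministic; it only differs from what a sequential merge (or a different split count) would produce.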

@Edenzzzz (author) commented Sep 12, 2025

Though I thought the threads in a block just collectively reduce along the head dim in threadblock_sync_state, one position and head at a time, no matter which head/seq_id each picks up?
As long as changing one request's kv len doesn't affect other requests, it should be fine.

@yzh119 (Collaborator) commented Sep 12, 2025

@happierpig has done some work on changing the reduction order before.

> As long as changing one request's kv len doesn't affect other requests, it should be fine

It matters if you want to guarantee reproducibility across prefill and decode; my question is, what kind of reproducibility do you want to achieve in this PR?

@Edenzzzz (author) commented Sep 12, 2025

> It matters if you want to guarantee reproducibility across prefill and decode; my question is, what kind of reproducibility do you want to achieve in this PR?

We want to ensure that changing the batch size (adding requests) does not change the output of individual requests, using one kernel. I'm not sure how the reduction affects that? Both cases use the same reduction and batch prefill kernel.

@Edenzzzz Edenzzzz changed the title feat: Batch-invariant FA2 Prefill & Decode feat: Batch-size invariant FA2 Prefill & Decode Sep 12, 2025
@happierpig (Collaborator) commented Sep 12, 2025

> the order might change for different kv-lengths.

I think the current merge_states should be fine for batch-invariant reproducibility (which IMO means a single request's attention gives exactly the same result no matter how many requests are batched together). Since the parallel reduction in merge states is deterministic for any given kv-len, the order, and therefore the result, is the same for a given kv-len.

@Edenzzzz (author) commented Sep 12, 2025

One more note: in the CUDA graph case, when the provided split size would launch too many CTAs, we can binary-search a split size that launches just under max_batch_size_if_split CTAs. This can be done in a follow-up PR.
It's better to disable split-kv for CUDA graphs, though: if you split too much and hit that binary search, the chunk size is again non-deterministic.
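A toy sketch of that follow-up idea (the helper name and `max_num_ctas` are hypothetical stand-ins for the CTA budget captured with the graph; a simple doubling search is shown where the comment above suggests binary search):

```python
def pick_split_size(kv_lens, max_num_ctas, base_split=256):
    # Grow the split size until the total number of KV chunks (one CTA per
    # chunk, per request) fits within the captured CTA budget.
    split = base_split
    while sum((l + split - 1) // split for l in kv_lens) > max_num_ctas:
        split *= 2
    return split


print(pick_split_size([8192, 4096, 4096], max_num_ctas=12))  # -> 2048
```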

@Fridge003:

Do we need to add a flag for disabling split-kv, since we ultimately want to use CUDA graphs for decoding?

@Edenzzzz (author):

> Do we need to add a flag for disabling split-kv, since we ultimately want to use CUDA graphs for decoding?

It should work now.

@yzh119 (Collaborator) commented Sep 13, 2025

To be clearer about my comments on "prefill and decode consistency":

I'm worried about use cases where we manually merge attention outputs from different KV components (e.g. in chunked prefill or speculative decoding). It's okay to ignore these cases at the moment, but we should know their possible effect on reproducibility.

@Edenzzzz (author) commented Sep 13, 2025

> I'm worried about use cases where we manually merge attention outputs from different KV components (e.g. in chunked prefill or speculative decoding). It's okay to ignore these cases at the moment, but we should know their possible effect on reproducibility.

Yes, I think if the chunk size changes, then a fixed split size can still lead to differences in the tail chunk of each request.

Two issues:

  1. SGLang uses a new_token_ratio that increases when decode hits OOM and decreases when it runs successfully. Different chunk sizes can lead to the kv len being truncated in the tail chunk. For example, assume the chunk size is 4096 and the kv split size is 2048: the first request is computed using split-kv. However, if the chunk size then drops to 2048, it will not use split-kv.
  2. (Should be fine, as FCFS is used by default.) If you use Longest Prefix Match instead of FCFS, the order in which the scheduler fills up that chunk is non-deterministic (the last request gets truncated).

As for speculative decoding, the varying number of tokens to verify might also lead to a non-deterministic tail (though I'm not sure people ever use speculative decoding in RL).

In these cases we need to consistently use one BatchPagedPrefill kernel and avoid merging different chunks with a separate merge_state kernel.
Radix cache shouldn't cause problems as long as you save the KV cache first and only use the paged kernel.

cc @Fridge003 @merrymercy @zhyncs
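For reference, a sketch of the separate merge path mentioned above, using flashinfer.merge_state (shapes are illustrative and the code is mine, not from this PR):

```python
import torch

import flashinfer

# Partial attention states of the same queries against two KV chunks:
# v_* are partial outputs (tokens, heads, head_dim), s_* their logsumexp (tokens, heads).
q_len, num_heads, head_dim = 128, 8, 128
v_a = torch.randn(q_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16)
s_a = torch.randn(q_len, num_heads, device="cuda:0", dtype=torch.float32)
v_b = torch.randn(q_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16)
s_b = torch.randn(q_len, num_heads, device="cuda:0", dtype=torch.float32)

# If chunk boundaries move between runs, the partial states themselves change,
# so the merged tail is no longer bit-stable even with a fixed split size.
v_merged, s_merged = flashinfer.merge_state(v_a, s_a, v_b, s_b)
```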

@Fridge003:

> Yes, I think if the chunk size changes, then a fixed split size can still lead to differences in the tail chunk of each request. [...] In these cases we have to disable split kv. So we will need to stick to FCFS and avoid changing new_token_ratio.

For deterministic inference we can add some assumptions, like disabling radix cache, disabling speculative decoding, or even disabling chunked prefill. Part of the speed can be sacrificed for more stable output.

@Edenzzzz (author):

Prefill tests also pass.

@yzh119 (Collaborator) left a comment


LGTM in general, left some minor suggestions.

Comment thread: csrc/batch_prefill.cu (outdated)
  at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
  int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
- int64_t head_dim_vo, bool causal, int64_t window_left) {
+ int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size = -1,
@yzh119 (Collaborator):

I don't encourage setting default values in C++ APIs (as these APIs are directly exposed as Python bindings); is there any use case for them?

@Edenzzzz (author):

They were just example default values for Python interface writers; removing them now.

  IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
  int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
- int64_t head_dim_vo, bool causal, int64_t window_left, TVMStreamHandle cuda_stream) {
+ int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size,
@yzh119 (Collaborator):

For TVM bindings, @MasterJH5574 could probably help with checking them.

@MasterJH5574 (Collaborator):

Sorry, I just noticed this change. On the TVM side we don't use fixed_split_size at the moment, so we will need to remove it from the parameter list and use a default value. We will submit a fix later.

@MasterJH5574 (Collaborator):

Updated in #1680

@yzh119 merged commit b159470 into flashinfer-ai:main on Sep 15, 2025
2 checks passed
@Edenzzzz deleted the invariant_prefill branch on September 15, 2025 at 13:25
MasterJH5574 added a commit to MasterJH5574/flashinfer that referenced this pull request Sep 15, 2025
This PR uses the default value -1 for `fixed_split_size` introduced in
PR flashinfer-ai#1675, to keep the interface consistent with the TVM side.
MasterJH5574 added a commit that referenced this pull request Sep 15, 2025
## 📌 Description

This PR uses the default value -1 for `fixed_split_size` introduced in
PR #1675, to keep the interface consistent with the TVM side.

## 🔍 Related Issues

N/A

