
lookahead: fix n_seq_max and kv_unified configuration#18730

Closed
pestopoppa wants to merge 1 commit into ggml-org:master from pestopoppa:fix-lookahead-n-seq-max

Conversation

@pestopoppa
Contributor

Summary

Fixes llama-lookahead configuration issues that have been broken since PR #14482 (July 2025).

Note: This PR depends on #18729 for the batch init fix. Both PRs are needed for lookahead to fully work.

Root Cause

Two lookahead-specific configuration issues:

1. Sequence count (n_seq_max)

PR #14482 changed seq_id validation from LLAMA_MAX_SEQ (large constant) to n_seq_max (context-specific). Lookahead needs W + G + 1 = 31 sequences for parallel Jacobi decoding, but params.n_parallel defaulted to 1.

2. KV unified mode

Batch splitting with "coupled sequences" requires unified KV cache. Lookahead didn't enable this, causing:

    split_equal: sequential split is not supported when there are coupled sequences

Fix

    // lookahead requires W + G + 1 sequences for parallel Jacobi decoding
    params.n_parallel = W + G + 1;

    // unified KV cache is required for coupled sequences in batch splitting
    params.kv_unified = true;

Bug Timeline

Date        PR       Effect
Nov 2023    #4207    lookahead.cpp created - works
July 2025   #14482   seq_id validation changed - breaks lookahead

Testing

With both this PR and #18729 applied:

    encoded    4 tokens in    0.138 seconds
    decoded   51 tokens in   71.591 seconds
    n_accept  = 13

Dependencies


Bug history researched with Claude.

llama-lookahead has been broken since PR ggml-org#14482 (July 2025) which changed
seq_id validation from LLAMA_MAX_SEQ constant to context-specific n_seq_max.

Two lookahead-specific issues:

1. n_seq_max: Lookahead needs W + G + 1 = 31 sequences for parallel Jacobi
   decoding, but params.n_parallel defaulted to 1.
   Fix: Set params.n_parallel = W + G + 1 before context creation.

2. KV unified: Batch splitting with coupled sequences requires unified KV
   cache mode, but lookahead didn't enable it.
   Fix: Set params.kv_unified = true.

Bug timeline:
- Nov 2023: lookahead.cpp created, worked with LLAMA_MAX_SEQ constant
- July 2025: PR ggml-org#14482 changed to n_seq_max validation, broke lookahead

Note: This PR depends on ggml-org#18729 for the batch init fix (params.n_ctx ->
llama_n_ctx). Both PRs are needed for lookahead to fully work.

Tested with Qwen2.5-Coder-0.5B: lookahead generates output with n_accept > 0.

Bug history researched with Claude.
@ngxson
Collaborator

ngxson commented Jan 10, 2026

kv_unified is only handled by server, no use to add it here. please stop sending us low-quality PRs

@ngxson ngxson closed this Jan 10, 2026
@ngxson
Collaborator

ngxson commented Jan 10, 2026

the only thing that may be needed to make it work is to correctly specify params.n_parallel, you may merge your changes into one PR instead of spamming us with smaller ones

@ggerganov
Member

Actually, I think the unified KV is required for this example. Haven't run it in a while, but it was developed when we only had unified cache. And probably forgot to update it when we switched to non-unified being the default.

@ngxson
Collaborator

ngxson commented Jan 10, 2026

hmm ok I see, I thought this was the speculative decoding example. it seems like the lookahead example relies extensively on "branching" the generation via parallel seqs.

It's a valid use of unified cache here then, sorry for the oversight. @pestopoppa it's better if you can merge all changes related to lookahead into one PR, this way we can test it easier.

pestopoppa added a commit to pestopoppa/llama.cpp that referenced this pull request Jan 10, 2026
Lookahead decoding requires:
- W + G + 1 = 31 sequences for parallel Jacobi decoding
- Unified KV cache for coupled sequences in batch splitting

These requirements were broken after PR ggml-org#14482 changed validation logic.

Consolidates fix from PR ggml-org#18730 per maintainer request.

Commit message drafted with Claude.
@pestopoppa
Contributor Author

@pestopoppa it's better if you can merge all changes related to lookahead into one PR, this way we can test it easier.

Thanks for taking another look and appreciating the nuance. I've consolidated the fix from this PR into #18729 (commit b917cd2).

The PR now includes:

  1. llama_n_ctx(ctx) fix for batch init (both lookup and lookahead)
  2. params.n_parallel = W + G + 1 for the 31 sequences
  3. params.kv_unified = true for coupled sequence batch splitting

Ready for testing when you have a chance.

ggerganov pushed a commit that referenced this pull request Jan 30, 2026
* lookup, lookahead: fix crash when n_ctx not specified

Since PR #16653 (Dec 15, 2025), the default n_ctx is 0 to enable automatic
GPU memory fitting. This causes llama-lookup and llama-lookahead to crash
when run without explicit -c flag:

    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded")

Root cause: Both examples use params.n_ctx directly for batch initialization,
but params.n_ctx remains 0 even after the context is properly initialized
to n_ctx_train internally.

Bug history:
- Nov 2023: lookahead.cpp created (PR #4207) with params.n_ctx pattern
- Dec 2023: lookup.cpp created (PR #4484) with same pattern
- Nov 2024: default n_ctx changed to 4096 (PR #10136) - bug dormant
- Dec 2025: default n_ctx changed to 0 (PR #16653) - bug activated

The bug was dormant for 2+ years because params.n_ctx defaulted to 512,
then 4096. PR #16653 changed it to 0 for GPU auto-fitting, triggering
the crash.

Fix: Use llama_n_ctx(ctx) to get the actual runtime context size, matching
the pattern already used elsewhere in lookup.cpp (line 72) and in
speculative.cpp/speculative-simple.cpp.

Tested: llama-lookup now works without -c flag (12.5% acceptance on
Gemma-3-1B).

Note: llama-lookahead has a separate pre-existing issue with sequence
initialization (n_seq_max=1 vs W+G+1 needed) that is unrelated to this fix.

* lookahead: fix n_seq_max and kv_unified configuration

Lookahead decoding requires:
- W + G + 1 = 31 sequences for parallel Jacobi decoding
- Unified KV cache for coupled sequences in batch splitting

These requirements were broken after PR #14482 changed validation logic.

Consolidates fix from PR #18730 per maintainer request.

Commit message drafted with Claude.
4b1tQu4ntN3k0 pushed a commit to 4b1tQu4ntN3k0/llama.cpp that referenced this pull request Feb 2, 2026
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026