
lookup, lookahead: fix crash when n_ctx not specified #18729

Merged
ggerganov merged 2 commits into ggml-org:master from pestopoppa:fix-lookup-lookahead-batch-init
Jan 30, 2026

Conversation

@pestopoppa
Contributor

Summary

Fixes a crash in llama-lookup and llama-lookahead when run without explicit -c flag:

GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded")

Root Cause

Both examples use params.n_ctx directly for batch initialization:

// lookup.cpp:109
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);

// lookahead.cpp:118
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

Since #16653 changed the default n_ctx to 0 (for GPU auto-fitting), params.n_ctx remains 0 even after the context is properly initialized. This creates a zero-sized batch that crashes on the first common_batch_add().

Bug History

This bug was dormant for 2+ years:

Date      PR      Default n_ctx  Effect
Nov 2023  #4207   512            lookahead.cpp created - works
Dec 2023  #4484   512            lookup.cpp created - works
Nov 2024  #10136  4096           Default increased - works
Dec 2025  #16653  0              Auto-fitting enabled - CRASHES

The pattern was always incorrect, but it only triggered once the default n_ctx became 0.

Fix

Use llama_n_ctx(ctx) to get the actual runtime context size:

// Before
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);

// After
llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx), 0, 1);

This matches:

  • The pattern already used in lookup.cpp:72 for max_context_size
  • The pattern used in speculative.cpp and speculative-simple.cpp

Testing

# Before fix (crashes):
llama-lookup -m model.gguf -f prompt.txt --draft 4 -n 50
# GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded")

# After fix (works):
llama-lookup -m model.gguf -f prompt.txt --draft 4 -n 50
# n_accept = 1, accept = 12.500%

Note

llama-lookahead has a separate pre-existing issue with sequence initialization (n_seq_max=1 when it needs W+G+1) that is unrelated to this batch size fix.

Since PR ggml-org#16653 (Dec 15, 2025), the default n_ctx is 0 to enable automatic
GPU memory fitting. This causes llama-lookup and llama-lookahead to crash
when run without explicit -c flag:

    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded")

Root cause: Both examples use params.n_ctx directly for batch initialization,
but params.n_ctx remains 0 even after the context is properly initialized
to n_ctx_train internally.

Bug history:
- Nov 2023: lookahead.cpp created (PR ggml-org#4207) with params.n_ctx pattern
- Dec 2023: lookup.cpp created (PR ggml-org#4484) with same pattern
- Nov 2024: default n_ctx changed to 4096 (PR ggml-org#10136) - bug dormant
- Dec 2025: default n_ctx changed to 0 (PR ggml-org#16653) - bug activated

The bug was dormant for 2+ years because params.n_ctx defaulted to 512,
then 4096. PR ggml-org#16653 changed it to 0 for GPU auto-fitting, triggering
the crash.

Fix: Use llama_n_ctx(ctx) to get the actual runtime context size, matching
the pattern already used elsewhere in lookup.cpp (line 72) and in
speculative.cpp/speculative-simple.cpp.

Tested: llama-lookup now works without -c flag (12.5% acceptance on
Gemma-3-1B).

Note: llama-lookahead has a separate pre-existing issue with sequence
initialization (n_seq_max=1 vs W+G+1 needed) that is unrelated to this fix.
pestopoppa added a commit to pestopoppa/llama.cpp that referenced this pull request Jan 10, 2026
llama-lookahead has been broken since PR ggml-org#14482 (July 2025) which changed
seq_id validation from LLAMA_MAX_SEQ constant to context-specific n_seq_max.

Two lookahead-specific issues:

1. n_seq_max: Lookahead needs W + G + 1 = 31 sequences for parallel Jacobi
   decoding, but params.n_parallel defaulted to 1.
   Fix: Set params.n_parallel = W + G + 1 before context creation.

2. KV unified: Batch splitting with coupled sequences requires unified KV
   cache mode, but lookahead didn't enable it.
   Fix: Set params.kv_unified = true.

Bug timeline:
- Nov 2023: lookahead.cpp created, worked with LLAMA_MAX_SEQ constant
- July 2025: PR ggml-org#14482 changed to n_seq_max validation, broke lookahead

Note: This PR depends on ggml-org#18729 for the batch init fix (params.n_ctx ->
llama_n_ctx). Both PRs are needed for lookahead to fully work.

Tested with Qwen2.5-Coder-0.5B: lookahead generates output with n_accept > 0.

Bug history researched with Claude.
Lookahead decoding requires:
- W + G + 1 = 31 sequences for parallel Jacobi decoding
- Unified KV cache for coupled sequences in batch splitting

These requirements were broken after PR ggml-org#14482 changed validation logic.

Consolidates fix from PR ggml-org#18730 per maintainer request.

Commit message drafted with Claude.
@pestopoppa
Contributor Author

@ngxson Per your request, I've consolidated the fix from PR #18730 into this PR.

New commit: b917cd2 - adds sequence configuration for lookahead:

  • params.n_parallel = W + G + 1 (31 sequences for Jacobi decoding)
  • params.kv_unified = true (required for coupled sequences)

This PR now contains all lookahead-related fixes and is ready for review.


Comment drafted with Claude.

@ggerganov ggerganov merged commit 1488339 into ggml-org:master Jan 30, 2026
75 of 76 checks passed
4b1tQu4ntN3k0 pushed a commit to 4b1tQu4ntN3k0/llama.cpp that referenced this pull request Feb 2, 2026
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026

4 participants