
llama : adjust default context size + print warnings #10136

Merged
ggerganov merged 2 commits into master from gg/default-ctx on Nov 2, 2024

Conversation

ggerganov (Member) commented Nov 2, 2024

Fixes #8817 and #9563 (comment)

By default, the examples will now use a context size of 4096 instead of the model's training context. In many cases the training context can be very large - 32k to 128k tokens - which causes enormous KV cache allocations and out-of-memory failures on regular hardware.

Also, add warning logs when the specified context size per sequence does not match the training context.
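The behavior described above can be sketched as follows. This is a minimal Python model, not the actual C++ implementation; `resolve_n_ctx` is a hypothetical name, and the exact text of the second warning is an assumption. The names and the first message mirror the log line quoted later in the thread.

```python
# Hypothetical sketch of the new default-context behavior: fall back to
# 4096 when no context size was requested, and emit a warning whenever the
# per-sequence context does not match the training context.
DEFAULT_N_CTX = 4096  # new default introduced by this PR

def resolve_n_ctx(requested, n_ctx_train):
    """Return the effective context size and any warning lines to print."""
    n_ctx = requested if requested else DEFAULT_N_CTX
    warnings = []
    if n_ctx < n_ctx_train:
        warnings.append(
            "n_ctx_per_seq (%d) < n_ctx_train (%d) -- "
            "the full capacity of the model will not be utilized"
            % (n_ctx, n_ctx_train))
    elif n_ctx > n_ctx_train:
        # Assumed wording; the PR only says a warning is printed on mismatch.
        warnings.append(
            "n_ctx_per_seq (%d) > n_ctx_train (%d) -- "
            "possible training context overflow" % (n_ctx, n_ctx_train))
    return n_ctx, warnings

# Without -c, a 131072-token model now gets a 4096-token context plus a warning:
n_ctx, warns = resolve_n_ctx(None, 131072)
```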

@ggerganov ggerganov requested review from ngxson and slaren November 2, 2024 10:39
ngxson (Collaborator) left a comment

Thanks! This should prevent me from burning my swapfile whenever I forget to specify -c

Tested and it shows the log too:

> ./llama-cli -m ../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -cnv -p "You are a helpful assistant"
...
llama_new_context_with_model: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
...

@github-actions github-actions bot added the devops improvements to build systems and github actions label Nov 2, 2024
ggerganov (Member, Author) commented

Is 4096 a good value, or should we go lower?

ngxson (Collaborator) commented Nov 2, 2024

According to HF Hub statistics, the most-used model nowadays is Llama 3 (3.1, 3.2) 8B.

With a context size of 4096, the KV cache takes around 512 MB, which I think is a very reasonable amount.
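The ~512 MB figure can be reproduced with back-of-the-envelope arithmetic. The Llama 3 / 3.1 8B hyperparameters below (32 layers, 8 KV heads, head dim 128) come from the model's architecture, not from this thread, and an f16 KV cache is assumed:

```python
# Back-of-the-envelope KV cache size for Llama 3 / 3.1 8B at n_ctx = 4096.
# Model hyperparameters are assumptions from the published architecture:
n_layer   = 32    # transformer blocks
n_kv_head = 8     # grouped-query attention KV heads
head_dim  = 128   # dimension per head
n_ctx     = 4096  # new default context size
bytes_f16 = 2     # assuming the default f16 cache type

# K and V are each cached per layer, per position, per KV head:
kv_bytes = 2 * n_layer * n_ctx * n_kv_head * head_dim * bytes_f16
print(kv_bytes / 2**20, "MiB")  # → 512.0 MiB
```

At a 131072-token training context the same formula gives 16 GiB, which is the failure mode this PR avoids.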

@ggerganov ggerganov merged commit 1926d6e into master Nov 2, 2024
@ggerganov ggerganov deleted the gg/default-ctx branch November 2, 2024 13:18
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
* llama : adjust default context size + print warnings

ggml-ci

* ggml-ci : add missing gpu-layers + adjust context sizes
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
pestopoppa added a commit to pestopoppa/llama.cpp that referenced this pull request Jan 10, 2026
Since PR ggml-org#16653 (Dec 15, 2025), the default n_ctx is 0 to enable automatic
GPU memory fitting. This causes llama-lookup and llama-lookahead to crash
when run without explicit -c flag:

    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded")

Root cause: Both examples use params.n_ctx directly for batch initialization,
but params.n_ctx remains 0 even after the context is properly initialized
to n_ctx_train internally.

Bug history:
- Nov 2023: lookahead.cpp created (PR ggml-org#4207) with params.n_ctx pattern
- Dec 2023: lookup.cpp created (PR ggml-org#4484) with same pattern
- Nov 2024: default n_ctx changed to 4096 (PR ggml-org#10136) - bug dormant
- Dec 2025: default n_ctx changed to 0 (PR ggml-org#16653) - bug activated

The bug was dormant for 2+ years because params.n_ctx defaulted to 512,
then 4096. PR ggml-org#16653 changed it to 0 for GPU auto-fitting, triggering
the crash.

Fix: Use llama_n_ctx(ctx) to get the actual runtime context size, matching
the pattern already used elsewhere in lookup.cpp (line 72) and in
speculative.cpp/speculative-simple.cpp.

Tested: llama-lookup now works without -c flag (12.5% acceptance on
Gemma-3-1B).

Note: llama-lookahead has a separate pre-existing issue with sequence
initialization (n_seq_max=1 vs W+G+1 needed) that is unrelated to this fix.
pestopoppa added a commit to pestopoppa/llama.cpp that referenced this pull request Jan 10, 2026
ggerganov pushed a commit that referenced this pull request Jan 30, 2026
* lookup, lookahead: fix crash when n_ctx not specified


* lookahead: fix n_seq_max and kv_unified configuration

Lookahead decoding requires:
- W + G + 1 = 31 sequences for parallel Jacobi decoding
- Unified KV cache for coupled sequences in batch splitting

These requirements were broken after PR #14482 changed validation logic.

Consolidates fix from PR #18730 per maintainer request.

Commit message drafted with Claude.
4b1tQu4ntN3k0 pushed a commit to 4b1tQu4ntN3k0/llama.cpp that referenced this pull request Feb 2, 2026
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026

Labels

devops improvements to build systems and github actions

Successfully merging this pull request may close these issues.

Bug: n_ctx will reuse n_ctx_train when --ctx_size not set and make deepseek-v2 models meet out of memory crash even on a small output length.

3 participants