Add multi-branch (beam) search to the constrained decoder #518
Greedy argmax can commit to a locally-best token that leads nowhere good. This adds a beam search that explores several short continuations at once and keeps the highest-scoring one (by mean log-probability), reusing the same token profile, no-repeat-ngram guard, sentence-boundary stop, single-line masking, and confidence suppression as the greedy path. The algorithm runs against a BeamLogitsProvider closure so it is fully unit-tested against a scripted fake (no model needed); the live llama adapter (beamLogits/syncBeamSequence) is a thin KV-sync method that trims the shared sequence to a branch's prefix and re-accepts the rest. A closure rather than a protocol object because CotabbyInferenceEngine is a noncopyable C++ type. Gated behind the existing cotabbyConstrainedDecoderEnabled flag (default off) plus a new cotabbyConstrainedBeamWidth (default 1 = the existing greedy decode), so shipped behavior is unchanged. Enabling beam by default still needs on-device quality/latency evaluation.
        let blocked = RepetitionGuard.blockedTokens(
            history: branch.tokenIDs,
            ngramSize: configuration.noRepeatNgramSize
        )
        let candidates = ConstrainedSampler.rankedAdmissibleTokens(
            logits: logits,
            profile: profile,
            admissibleTokenIDs: nil,
            topK: configuration.topK,
            blockedTokenIDs: blocked
        )
        for tokenID in candidates {
            if profile.isEndOfGeneration(tokenID) {
                completed.append(branch)
                continue
            }
            if isSingleLine, profile.isNewline(tokenID) {
                completed.append(branch)
                continue
            }
            let tokenBytes = profile.bytes(for: tokenID)
            let child = extend(branch, by: tokenID, tokenBytes: tokenBytes, logits: logits)
            if Self.completesSentence(child.bytes, lastTokenBytes: tokenBytes) {
                completed.append(child)
            } else {
                live.append(child)
            }
        }
    }
Branch silently dropped when no admissible tokens remain
When rankedAdmissibleTokens returns an empty list (all tokens excluded or blocked by the no-repeat-ngram guard), the for tokenID in candidates loop body never executes. The branch is appended to neither live nor completed, so it disappears silently. In contrast, the greedy constrained decoder breaks with stopReason = "no_admissible_token" and returns whatever it has generated so far. A branch with partial output that cannot be extended should be treated as a completion (completed.append(branch)) rather than discarded, to stay consistent with the greedy behavior.
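A minimal sketch of the suggested fix, assuming simplified stand-ins for the PR's types: `Branch`, `expand`, and the `isStop` closure are hypothetical names used only for illustration, and the real expansion loop also handles newline masking and sentence boundaries. The only point shown is the early guard that routes a branch with no candidates to `completed`.

```swift
// Hypothetical stand-in for the PR's branch type; only the field this sketch needs.
struct Branch: Equatable {
    var tokenIDs: [Int]
}

/// Expands one branch. `candidates` plays the role of the ranked admissible
/// token list; `isStop` plays the role of the end-of-generation check.
func expand(
    _ branch: Branch,
    candidates: [Int],
    isStop: (Int) -> Bool,
    live: inout [Branch],
    completed: inout [Branch]
) {
    // Proposed guard: a branch that cannot be extended is kept as a
    // completion, mirroring the greedy "no_admissible_token" stop.
    guard !candidates.isEmpty else {
        completed.append(branch)
        return
    }
    for tokenID in candidates {
        if isStop(tokenID) {
            completed.append(branch)
            continue
        }
        live.append(Branch(tokenIDs: branch.tokenIDs + [tokenID]))
    }
}

// A branch with partial output and no admissible continuation.
var live: [Branch] = []
var completed: [Branch] = []
let stuck = Branch(tokenIDs: [7, 8])
expand(stuck, candidates: [], isStop: { _ in false }, live: &live, completed: &completed)
```

With the guard in place the stuck branch survives in `completed`, so its partial output can still compete in the final ranking.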
    /// The admissible token ids for a step, ranked highest-logit first. Survivors are the same set
    /// `selectToken` would consider — in-range, not `profile.isExcluded`, not in `blockedTokenIDs`,
    /// and, when `admissibleTokenIDs` is non-nil, members of that set — and at most `topK` are
    /// returned. This is the multi-candidate form of `selectToken`: the beam search expands a branch
    /// across these instead of committing to the single best. Ties break by lower id for determinism.
Docstring overstates equivalence with selectToken's topK semantics
The comment says "Survivors are the same set selectToken would consider", but the two functions handle topK differently. selectToken pre-filters to the top-K tokens by raw logit (via candidatePool) and then applies constraints, so a token at raw-logit rank K+1 is never seen even if it would pass all filters. rankedAdmissibleTokens filters the full vocabulary first and then takes the top-K survivors, so a token at raw-logit rank K+1 that passes all filters can appear whenever a filter removes one of the raw top-K. For any topK < vocabSize the beam search can therefore explore a broader set than greedy. That is intentional, but the "same set" claim is inaccurate and may confuse future maintainers comparing the two code paths.
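To make the difference concrete, here is a small sketch of the two orderings. The function names `topKThenFilter` and `filterThenTopK` are hypothetical and stand in for the behavior the review attributes to selectToken and rankedAdmissibleTokens; the real functions take a token profile and more parameters.

```swift
/// Token ids ranked by logit, highest first; ties break by lower id.
func ranked(_ logits: [Float]) -> [Int] {
    logits.indices.sorted { a, b in
        logits[a] != logits[b] ? logits[a] > logits[b] : a < b
    }
}

/// Greedy-style order as described in the review: cut to raw top-K, then filter.
func topKThenFilter(_ logits: [Float], topK: Int, excluded: Set<Int>) -> [Int] {
    ranked(logits).prefix(topK).filter { !excluded.contains($0) }
}

/// Beam-style order as described in the review: filter, then cut to top-K survivors.
func filterThenTopK(_ logits: [Float], topK: Int, excluded: Set<Int>) -> [Int] {
    Array(ranked(logits).filter { !excluded.contains($0) }.prefix(topK))
}

// Token 0 has the highest logit but is excluded, so it wastes a raw top-K slot.
let logits: [Float] = [5, 4, 3, 2]
let excluded: Set<Int> = [0]
let greedyPool = topKThenFilter(logits, topK: 2, excluded: excluded) // [1]
let beamPool = filterThenTopK(logits, topK: 2, excluded: excluded)   // [1, 2]
```

Token 2 sits at raw rank 3, outside the raw top-2, yet it appears in the filter-first pool because the excluded token 0 no longer occupies a slot.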
Summary
Adds a multi-branch (beam) search to the constrained decoder so it can explore several short
continuations at once and keep the highest-scoring one, instead of committing to a single greedy
token at each step (greedy argmax can pick a locally-best token that leads nowhere good).
The search is written against a small BeamLogitsProvider closure ("give me the next-token logits
for this branch's token path"), so the whole algorithm is unit-tested against a scripted closure that
returns fixed logits, with no model required. The live llama adapter is a thin KV-sync method on
the runtime (beamLogits/syncBeamSequence): before reading a branch's logits it trims the shared
sequence's KV cache back to the longest shared prefix and re-accepts the remainder, all under the
existing autocompleteLock, and the existing KV-trim defer restores prompt-only state afterward. (A
closure rather than a protocol object because the inference engine is a noncopyable C++ type that
cannot be stored in a separate adapter object.)
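The trim-and-re-accept step can be sketched against a fake sequence, much as the tests use a scripted closure. `FakeSequence`, `commonPrefixLength`, and `sync` are hypothetical names for illustration; the real syncBeamSequence operates on the engine's KV cache, not on a plain array.

```swift
/// Hypothetical stand-in for the engine's sequence state: just the accepted tokens.
struct FakeSequence {
    var tokens: [Int] = []
    var acceptCalls = 0

    /// Drops everything past `length`, like trimming the KV cache.
    mutating func trim(toLength length: Int) {
        tokens.removeLast(tokens.count - length)
    }

    /// Appends one token, like accepting it into the sequence.
    mutating func accept(_ tokenID: Int) {
        tokens.append(tokenID)
        acceptCalls += 1
    }
}

/// Length of the longest shared prefix of two token paths.
func commonPrefixLength(_ a: [Int], _ b: [Int]) -> Int {
    var n = 0
    while n < a.count, n < b.count, a[n] == b[n] { n += 1 }
    return n
}

/// Brings `sequence` to exactly `target`, re-accepting only the tokens
/// past the shared prefix.
func sync(_ sequence: inout FakeSequence, to target: [Int]) {
    let shared = commonPrefixLength(sequence.tokens, target)
    sequence.trim(toLength: shared)
    for tokenID in target[shared...] {
        sequence.accept(tokenID)
    }
}

var sequence = FakeSequence()
sync(&sequence, to: [1, 2, 3])    // accepts three tokens
sync(&sequence, to: [1, 2, 9, 4]) // keeps [1, 2], accepts two more
```

Switching between sibling branches therefore costs only the tokens after their divergence point, not a full replay of the path.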
Each branch reuses the same machinery as the greedy path: the token profile, the no-repeat-ngram
guard, the sentence-boundary stop, single-line newline masking, and the confidence-suppression gate.
Branches are scored by cumulative log-probability and the result is ranked by mean log-probability.
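A rough sketch of that scoring, assuming each step's log-probability comes from a softmax over the step's logits. `ScoredBranch` and `logProbability` are hypothetical names, and returning negative infinity for an empty branch is an assumption of this sketch, not something the PR states.

```swift
import Foundation

/// Log-probability of `tokenID` under a softmax over `logits`,
/// computed with the usual max-subtraction for numerical stability.
func logProbability(of tokenID: Int, in logits: [Float]) -> Float {
    let maxLogit = logits.max() ?? 0
    let sumExp = logits.reduce(Float(0)) { $0 + exp($1 - maxLogit) }
    return logits[tokenID] - maxLogit - log(sumExp)
}

/// Hypothetical branch record carrying the running score.
struct ScoredBranch {
    var tokenIDs: [Int] = []
    var cumulativeLogprob: Float = 0

    /// Length-normalized score used for the final ranking.
    /// Assumption: an empty branch ranks last.
    var meanLogprob: Float {
        tokenIDs.isEmpty ? -Float.infinity : cumulativeLogprob / Float(tokenIDs.count)
    }

    /// Returns a child branch with one more token and an updated score.
    func extended(by tokenID: Int, logits: [Float]) -> ScoredBranch {
        ScoredBranch(
            tokenIDs: tokenIDs + [tokenID],
            cumulativeLogprob: cumulativeLogprob + logProbability(of: tokenID, in: logits)
        )
    }
}

// Two steps over a uniform two-token vocabulary: each step contributes log(0.5).
let uniform: [Float] = [0, 0]
let scored = ScoredBranch()
    .extended(by: 0, logits: uniform)
    .extended(by: 1, logits: uniform)
```

Ranking by the mean rather than the sum keeps longer completions from being penalized simply for containing more tokens.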
Gated behind the existing developer flags: cotabbyConstrainedDecoderEnabled (default off) plus a new
cotabbyConstrainedBeamWidth (default 1 = the existing single-path greedy decode). The shipping
sampler path is untouched.
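The double gate can be pictured as a small routing function. `DecoderFlags`, `DecodePath`, and `selectPath` are hypothetical names, and the mapping of "flag off" to the sampler path is this sketch's reading of the description, not code from the PR.

```swift
/// Hypothetical mirror of the two developer flags and their defaults.
struct DecoderFlags {
    var constrainedDecoderEnabled = false
    var constrainedBeamWidth = 1
}

/// The three decode paths the description distinguishes.
enum DecodePath: Equatable {
    case sampler
    case greedyConstrained
    case beam(width: Int)
}

/// Beam runs only when both gates are open; a width of 1 falls back to greedy.
func selectPath(_ flags: DecoderFlags) -> DecodePath {
    guard flags.constrainedDecoderEnabled else { return .sampler }
    return flags.constrainedBeamWidth > 1
        ? .beam(width: flags.constrainedBeamWidth)
        : .greedyConstrained
}

let defaultPath = selectPath(DecoderFlags())
let greedyPath = selectPath(DecoderFlags(constrainedDecoderEnabled: true))
let beamPath = selectPath(DecoderFlags(constrainedDecoderEnabled: true, constrainedBeamWidth: 4))
let widthOnlyPath = selectPath(DecoderFlags(constrainedBeamWidth: 4))
```

Setting only the width has no effect in this reading, which is what keeps shipped behavior unchanged under the defaults.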
Validation
Linked issues
None. Brings the constrained decoder up to a real multi-branch search.
Risk / rollout notes
cotabbyConstrainedBeamWidth > 1. With the defaults, behavior is byte-for-byte the existing
greedy/sampler path. No change to shipped suggestion behavior.
exercised by tests is the ~25-line EngineBeamStepper KV-sync (it uses the same
trimKV/acceptToken/getNextTokenLogits calls the greedy decoder already relies on). Enabling beam by
default still needs on-device evaluation of completion quality and per-keystroke latency (beam width
W decodes up to ~W× the greedy token count); that is the next step before promoting the flag.
trim-and-re-accept on each branch switch; not required for correctness.
Greptile Summary
This PR adds a multi-branch (beam) search mode to the constrained decoder, letting it explore several continuations per step and return the highest mean-log-probability one instead of committing greedily to each token. The new path is double-gated (cotabbyConstrainedDecoderEnabled AND cotabbyConstrainedBeamWidth > 1), so shipped behavior is unchanged.
- ConstrainedBeamSearch.swift implements the search engine as a pure closure-based algorithm (BeamLogitsProvider), reusing the same token profile, no-repeat-ngram guard, single-line newline masking, and sentence-boundary stop as the greedy path; 8 unit tests exercise all core stop conditions against a scripted logits closure.
- LlamaRuntimeCore adds runConstrainedBeamDecode and a syncBeamSequence KV-cache sync helper that trims the shared sequence to the longest common prefix of the current and target paths before re-accepting branch tokens, all under the existing autocompleteLock.
- ConstrainedSampler gains rankedAdmissibleTokens (the multi-candidate analogue of selectToken) and LlamaGenerationOptions gains beamWidth: Int = 1, intentionally excluded from SamplingFingerprint to preserve KV reuse.
Confidence Score: 3/5
Safe to merge as-shipped since the beam path is double-gated and off by default, but the decode loop holds autocompleteLock without any cancellation check, which will degrade responsiveness noticeably during dogfood testing with the flags enabled.
The new beam decode path does not check Task.isCancelled between steps, unlike both the sampled and greedy decoders that break early on cancel. Under fast typing, each new keystroke supersedes the previous request; for greedy/sampled the lock is freed after 1-2 tokens, but for beam the lock is held for the full W x N step search, making every subsequent request wait. A secondary gap is that branches with no remaining admissible tokens are silently discarded rather than collected as partial completions, inconsistent with the greedy decoder behavior.
Cotabby/Services/Runtime/LlamaRuntimeCore.swift (missing cancellation in runConstrainedBeamDecode) and Cotabby/Support/ConstrainedBeamSearch.swift (silent branch loss on empty candidate list).
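One way the cancellation gap could be closed is to hand the search a cancellation probe and check it once per beam step. The sketch below is hypothetical: `runSteps` and `CancellationFlag` are illustrative names, and in the real decoder the probe would more likely be a closure such as `{ Task.isCancelled }` passed in from the caller.

```swift
/// Hypothetical mutable flag standing in for task cancellation.
final class CancellationFlag {
    var isSet = false
}

/// Runs up to `maxSteps` beam steps, calling `step` once per iteration, and
/// stops early as soon as `isCancelled` reports true.
/// Returns the number of steps actually executed.
func runSteps(maxSteps: Int, isCancelled: () -> Bool, step: (Int) -> Void) -> Int {
    var executed = 0
    for index in 0..<maxSteps {
        // Checked before each step so a superseded request releases the
        // lock after at most one more step of work.
        if isCancelled() { break }
        step(index)
        executed += 1
    }
    return executed
}

// Simulate a keystroke arriving during the third step (index 2).
let flag = CancellationFlag()
let executed = runSteps(maxSteps: 10, isCancelled: { flag.isSet }) { index in
    if index == 2 { flag.isSet = true }
}
```

Here only three of the ten budgeted steps run. A finer-grained variant could also check between branches within a step, at the cost of more frequent probing.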
Important Files Changed
Sequence Diagram
    sequenceDiagram
        participant SE as LlamaSuggestionEngine
        participant RC as LlamaRuntimeCore
        participant BS as ConstrainedBeamSearch
        participant KV as Engine (KV cache)
        SE->>RC: "generate(options beamWidth>1)"
        RC->>RC: obtain sequence, hold autocompleteLock
        RC->>BS: search(nextLogits closure, profile, config)
        loop each beam step
            loop each live branch
                BS->>RC: nextLogits(branch.tokenIDs)
                RC->>KV: trimKV to common prefix
                RC->>KV: acceptToken for remainder
                KV-->>RC: logits[vocabSize]
                RC-->>BS: [Float]
                BS->>BS: expand across topK candidates
                BS->>BS: route stops to completed, rest to live
            end
            BS->>BS: prune live to beamWidth
        end
        BS-->>RC: [BeamCandidate] sorted by meanLogprob
        RC->>RC: shouldSuppress check
        RC->>KV: trimKV to prompt only
        RC-->>SE: best.text

Comments Outside Diff (1)
Cotabby/Services/Runtime/LlamaRuntimeCore.swift, line 419-542
Both the sampled decoder (runEngineSampledDecode) and the greedy constrained decoder (runConstrainedDecode) check Task.isCancelled at each token step and break early, releasing autocompleteLock promptly. runConstrainedBeamDecode holds the lock for the entire ConstrainedBeamSearch.search call with no cancellation hook anywhere in the search loop. During fast typing, every new keystroke supersedes the in-flight generation: for greedy/sampled this means only ~1–2 tokens are decoded before the lock is freed; for beam with width W and budget N, the next request waits up to W×N KV syncs before it can proceed.
Reviews (1): Last reviewed commit: "Add multi-branch (beam) search to the co..."