|
| 1 | +import { modelKey, normalizeProviderId } from "../../agents/model-selection-normalize.js"; |
| 2 | +import { |
| 3 | + resolveModelRefFromString, |
| 4 | + type ModelAliasIndex, |
| 5 | +} from "../../agents/model-selection-shared.js"; |
| 6 | +import { normalizeLowercaseStringOrEmpty } from "../../shared/string-coerce.js"; |
| 7 | + |
| 8 | +export type ModelDirectiveSelection = { |
| 9 | + provider: string; |
| 10 | + model: string; |
| 11 | + isDefault: boolean; |
| 12 | + alias?: string; |
| 13 | +}; |
| 14 | + |
| 15 | +const FUZZY_VARIANT_TOKENS = [ |
| 16 | + "lightning", |
| 17 | + "preview", |
| 18 | + "mini", |
| 19 | + "fast", |
| 20 | + "turbo", |
| 21 | + "lite", |
| 22 | + "beta", |
| 23 | + "small", |
| 24 | + "nano", |
| 25 | +]; |
| 26 | + |
| 27 | +function boundedLevenshteinDistance(a: string, b: string, maxDistance: number): number | null { |
| 28 | + if (a === b) { |
| 29 | + return 0; |
| 30 | + } |
| 31 | + if (!a || !b) { |
| 32 | + return null; |
| 33 | + } |
| 34 | + const aLen = a.length; |
| 35 | + const bLen = b.length; |
| 36 | + if (Math.abs(aLen - bLen) > maxDistance) { |
| 37 | + return null; |
| 38 | + } |
| 39 | + |
| 40 | + // Standard DP with early exit. O(maxDistance * minLen) in common cases. |
| 41 | + const prev = Array.from({ length: bLen + 1 }, (_, idx) => idx); |
| 42 | + const curr = Array.from({ length: bLen + 1 }, () => 0); |
| 43 | + |
| 44 | + for (let i = 1; i <= aLen; i++) { |
| 45 | + curr[0] = i; |
| 46 | + let rowMin = curr[0]; |
| 47 | + |
| 48 | + const aChar = a.charCodeAt(i - 1); |
| 49 | + for (let j = 1; j <= bLen; j++) { |
| 50 | + const cost = aChar === b.charCodeAt(j - 1) ? 0 : 1; |
| 51 | + curr[j] = Math.min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost); |
| 52 | + if (curr[j] < rowMin) { |
| 53 | + rowMin = curr[j]; |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + if (rowMin > maxDistance) { |
| 58 | + return null; |
| 59 | + } |
| 60 | + |
| 61 | + for (let j = 0; j <= bLen; j++) { |
| 62 | + prev[j] = curr[j] ?? 0; |
| 63 | + } |
| 64 | + } |
| 65 | + |
| 66 | + const dist = prev[bLen] ?? null; |
| 67 | + if (dist == null || dist > maxDistance) { |
| 68 | + return null; |
| 69 | + } |
| 70 | + return dist; |
| 71 | +} |
| 72 | + |
| 73 | +function scoreFuzzyMatch(params: { |
| 74 | + provider: string; |
| 75 | + model: string; |
| 76 | + fragment: string; |
| 77 | + aliasIndex: ModelAliasIndex; |
| 78 | + defaultProvider: string; |
| 79 | + defaultModel: string; |
| 80 | +}): { |
| 81 | + score: number; |
| 82 | + isDefault: boolean; |
| 83 | + variantCount: number; |
| 84 | + variantMatchCount: number; |
| 85 | + modelLength: number; |
| 86 | + key: string; |
| 87 | +} { |
| 88 | + const provider = normalizeProviderId(params.provider); |
| 89 | + const model = params.model; |
| 90 | + const fragment = normalizeLowercaseStringOrEmpty(params.fragment); |
| 91 | + const providerLower = normalizeLowercaseStringOrEmpty(provider); |
| 92 | + const modelLower = normalizeLowercaseStringOrEmpty(model); |
| 93 | + const haystack = `${providerLower}/${modelLower}`; |
| 94 | + const key = modelKey(provider, model); |
| 95 | + |
| 96 | + const scoreFragment = ( |
| 97 | + value: string, |
| 98 | + weights: { exact: number; starts: number; includes: number }, |
| 99 | + ) => { |
| 100 | + if (!fragment) { |
| 101 | + return 0; |
| 102 | + } |
| 103 | + let score = 0; |
| 104 | + if (value === fragment) { |
| 105 | + score = Math.max(score, weights.exact); |
| 106 | + } |
| 107 | + if (value.startsWith(fragment)) { |
| 108 | + score = Math.max(score, weights.starts); |
| 109 | + } |
| 110 | + if (value.includes(fragment)) { |
| 111 | + score = Math.max(score, weights.includes); |
| 112 | + } |
| 113 | + return score; |
| 114 | + }; |
| 115 | + |
| 116 | + let score = 0; |
| 117 | + score += scoreFragment(haystack, { exact: 220, starts: 140, includes: 110 }); |
| 118 | + score += scoreFragment(providerLower, { |
| 119 | + exact: 180, |
| 120 | + starts: 120, |
| 121 | + includes: 90, |
| 122 | + }); |
| 123 | + score += scoreFragment(modelLower, { |
| 124 | + exact: 160, |
| 125 | + starts: 110, |
| 126 | + includes: 80, |
| 127 | + }); |
| 128 | + |
| 129 | + // Best-effort typo tolerance for common near-misses like "claud" vs "claude". |
| 130 | + // Bounded to keep this cheap across large model sets. |
| 131 | + const distModel = boundedLevenshteinDistance(fragment, modelLower, 3); |
| 132 | + if (distModel != null) { |
| 133 | + score += (3 - distModel) * 70; |
| 134 | + } |
| 135 | + |
| 136 | + const aliases = params.aliasIndex.byKey.get(key) ?? []; |
| 137 | + for (const alias of aliases) { |
| 138 | + score += scoreFragment(normalizeLowercaseStringOrEmpty(alias), { |
| 139 | + exact: 140, |
| 140 | + starts: 90, |
| 141 | + includes: 60, |
| 142 | + }); |
| 143 | + } |
| 144 | + |
| 145 | + if (modelLower.startsWith(providerLower)) { |
| 146 | + score += 30; |
| 147 | + } |
| 148 | + |
| 149 | + const fragmentVariants = FUZZY_VARIANT_TOKENS.filter((token) => fragment.includes(token)); |
| 150 | + const modelVariants = FUZZY_VARIANT_TOKENS.filter((token) => modelLower.includes(token)); |
| 151 | + const variantMatchCount = fragmentVariants.filter((token) => modelLower.includes(token)).length; |
| 152 | + const variantCount = modelVariants.length; |
| 153 | + if (fragmentVariants.length === 0 && variantCount > 0) { |
| 154 | + score -= variantCount * 30; |
| 155 | + } else if (fragmentVariants.length > 0) { |
| 156 | + if (variantMatchCount > 0) { |
| 157 | + score += variantMatchCount * 40; |
| 158 | + } |
| 159 | + if (variantMatchCount === 0) { |
| 160 | + score -= 20; |
| 161 | + } |
| 162 | + } |
| 163 | + |
| 164 | + const defaultProvider = normalizeProviderId(params.defaultProvider); |
| 165 | + const isDefault = provider === defaultProvider && model === params.defaultModel; |
| 166 | + if (isDefault) { |
| 167 | + score += 20; |
| 168 | + } |
| 169 | + |
| 170 | + return { |
| 171 | + score, |
| 172 | + isDefault, |
| 173 | + variantCount, |
| 174 | + variantMatchCount, |
| 175 | + modelLength: modelLower.length, |
| 176 | + key, |
| 177 | + }; |
| 178 | +} |
| 179 | + |
| 180 | +export function resolveModelDirectiveSelection(params: { |
| 181 | + raw: string; |
| 182 | + defaultProvider: string; |
| 183 | + defaultModel: string; |
| 184 | + aliasIndex: ModelAliasIndex; |
| 185 | + allowedModelKeys: Set<string>; |
| 186 | +}): { selection?: ModelDirectiveSelection; error?: string } { |
| 187 | + const { raw, defaultProvider, defaultModel, aliasIndex, allowedModelKeys } = params; |
| 188 | + |
| 189 | + const rawTrimmed = raw.trim(); |
| 190 | + const rawLower = normalizeLowercaseStringOrEmpty(rawTrimmed); |
| 191 | + |
| 192 | + const pickAliasForKey = (provider: string, model: string): string | undefined => |
| 193 | + aliasIndex.byKey.get(modelKey(provider, model))?.[0]; |
| 194 | + |
| 195 | + const buildSelection = (provider: string, model: string): ModelDirectiveSelection => { |
| 196 | + const alias = pickAliasForKey(provider, model); |
| 197 | + return { |
| 198 | + provider, |
| 199 | + model, |
| 200 | + isDefault: provider === defaultProvider && model === defaultModel, |
| 201 | + ...(alias ? { alias } : undefined), |
| 202 | + }; |
| 203 | + }; |
| 204 | + |
| 205 | + const resolveFuzzy = (params: { |
| 206 | + provider?: string; |
| 207 | + fragment: string; |
| 208 | + }): { selection?: ModelDirectiveSelection; error?: string } => { |
| 209 | + const fragment = normalizeLowercaseStringOrEmpty(params.fragment); |
| 210 | + if (!fragment) { |
| 211 | + return {}; |
| 212 | + } |
| 213 | + |
| 214 | + const providerFilter = params.provider ? normalizeProviderId(params.provider) : undefined; |
| 215 | + |
| 216 | + const candidates: Array<{ provider: string; model: string }> = []; |
| 217 | + for (const key of allowedModelKeys) { |
| 218 | + const slash = key.indexOf("/"); |
| 219 | + if (slash <= 0) { |
| 220 | + continue; |
| 221 | + } |
| 222 | + const provider = normalizeProviderId(key.slice(0, slash)); |
| 223 | + const model = key.slice(slash + 1); |
| 224 | + if (providerFilter && provider !== providerFilter) { |
| 225 | + continue; |
| 226 | + } |
| 227 | + candidates.push({ provider, model }); |
| 228 | + } |
| 229 | + |
| 230 | + // Also allow partial alias matches when the user didn't specify a provider. |
| 231 | + if (!params.provider) { |
| 232 | + const aliasMatches: Array<{ provider: string; model: string }> = []; |
| 233 | + for (const [aliasKey, entry] of aliasIndex.byAlias.entries()) { |
| 234 | + if (!aliasKey.includes(fragment)) { |
| 235 | + continue; |
| 236 | + } |
| 237 | + aliasMatches.push({ |
| 238 | + provider: entry.ref.provider, |
| 239 | + model: entry.ref.model, |
| 240 | + }); |
| 241 | + } |
| 242 | + for (const match of aliasMatches) { |
| 243 | + const key = modelKey(match.provider, match.model); |
| 244 | + if (!allowedModelKeys.has(key)) { |
| 245 | + continue; |
| 246 | + } |
| 247 | + if (!candidates.some((c) => c.provider === match.provider && c.model === match.model)) { |
| 248 | + candidates.push(match); |
| 249 | + } |
| 250 | + } |
| 251 | + } |
| 252 | + |
| 253 | + if (candidates.length === 0) { |
| 254 | + return {}; |
| 255 | + } |
| 256 | + |
| 257 | + const scored = candidates |
| 258 | + .map((candidate) => { |
| 259 | + const details = scoreFuzzyMatch({ |
| 260 | + provider: candidate.provider, |
| 261 | + model: candidate.model, |
| 262 | + fragment, |
| 263 | + aliasIndex, |
| 264 | + defaultProvider, |
| 265 | + defaultModel, |
| 266 | + }); |
| 267 | + return Object.assign({ candidate }, details); |
| 268 | + }) |
| 269 | + .toSorted((a, b) => { |
| 270 | + if (b.score !== a.score) { |
| 271 | + return b.score - a.score; |
| 272 | + } |
| 273 | + if (a.isDefault !== b.isDefault) { |
| 274 | + return a.isDefault ? -1 : 1; |
| 275 | + } |
| 276 | + if (a.variantMatchCount !== b.variantMatchCount) { |
| 277 | + return b.variantMatchCount - a.variantMatchCount; |
| 278 | + } |
| 279 | + if (a.variantCount !== b.variantCount) { |
| 280 | + return a.variantCount - b.variantCount; |
| 281 | + } |
| 282 | + if (a.modelLength !== b.modelLength) { |
| 283 | + return a.modelLength - b.modelLength; |
| 284 | + } |
| 285 | + return a.key.localeCompare(b.key); |
| 286 | + }); |
| 287 | + |
| 288 | + const bestScored = scored[0]; |
| 289 | + const best = bestScored?.candidate; |
| 290 | + if (!best || !bestScored) { |
| 291 | + return {}; |
| 292 | + } |
| 293 | + |
| 294 | + const minScore = providerFilter ? 90 : 120; |
| 295 | + if (bestScored.score < minScore) { |
| 296 | + return {}; |
| 297 | + } |
| 298 | + |
| 299 | + return { selection: buildSelection(best.provider, best.model) }; |
| 300 | + }; |
| 301 | + |
| 302 | + const resolved = resolveModelRefFromString({ |
| 303 | + raw: rawTrimmed, |
| 304 | + defaultProvider, |
| 305 | + aliasIndex, |
| 306 | + }); |
| 307 | + |
| 308 | + if (!resolved) { |
| 309 | + const fuzzy = resolveFuzzy({ fragment: rawTrimmed }); |
| 310 | + if (fuzzy.selection || fuzzy.error) { |
| 311 | + return fuzzy; |
| 312 | + } |
| 313 | + return { |
| 314 | + error: `Unrecognized model "${rawTrimmed}". Use /models to list providers, or /models <provider> to list models.`, |
| 315 | + }; |
| 316 | + } |
| 317 | + |
| 318 | + const resolvedKey = modelKey(resolved.ref.provider, resolved.ref.model); |
| 319 | + if (allowedModelKeys.size === 0 || allowedModelKeys.has(resolvedKey)) { |
| 320 | + return { |
| 321 | + selection: { |
| 322 | + provider: resolved.ref.provider, |
| 323 | + model: resolved.ref.model, |
| 324 | + isDefault: resolved.ref.provider === defaultProvider && resolved.ref.model === defaultModel, |
| 325 | + alias: resolved.alias, |
| 326 | + }, |
| 327 | + }; |
| 328 | + } |
| 329 | + |
| 330 | + // If the user specified a provider/model but the exact model isn't allowed, |
| 331 | + // attempt a fuzzy match within that provider. |
| 332 | + if (rawLower.includes("/")) { |
| 333 | + const slash = rawTrimmed.indexOf("/"); |
| 334 | + const provider = normalizeProviderId(rawTrimmed.slice(0, slash).trim()); |
| 335 | + const fragment = rawTrimmed.slice(slash + 1).trim(); |
| 336 | + const fuzzy = resolveFuzzy({ provider, fragment }); |
| 337 | + if (fuzzy.selection || fuzzy.error) { |
| 338 | + return fuzzy; |
| 339 | + } |
| 340 | + } |
| 341 | + |
| 342 | + // Otherwise, try fuzzy matching across allowlisted models. |
| 343 | + const fuzzy = resolveFuzzy({ fragment: rawTrimmed }); |
| 344 | + if (fuzzy.selection || fuzzy.error) { |
| 345 | + return fuzzy; |
| 346 | + } |
| 347 | + |
| 348 | + return { |
| 349 | + error: `Model "${resolved.ref.provider}/${resolved.ref.model}" is not allowed. Use /models to list providers, or /models <provider> to list models.`, |
| 350 | + }; |
| 351 | +} |
0 commit comments