Skip to content

Commit 0399d3c

Browse files
[Search][Playground] Fix loading context limit for EIS model (#225360)
## Summary Updates the playground to properly load the right context token limit when using the EIS model as well as sending the model to the chat endpoint. Additionally updated the backend to load the correct chat parameters for the EIS model. Fixes an issue where a too-large-context error was received from the connector instead of being caught by the playground. Before: ![image](https://github.com/user-attachments/assets/578f841c-3ef7-4ec7-a2b3-d3c03fd8387f) After: ![image](https://github.com/user-attachments/assets/52a6fa9f-d1d6-42ab-acc1-83171bb2eb25) ### Checklist - [ ] Any text added follows [EUI's writing guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses sentence case text and includes [i18n support](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md) - [ ] [Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html) was added for features that require explanation or tutorials - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios - [ ] If a plugin configuration key changed, check if it needs to be allowlisted in the cloud and added to the [docker list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker) - [ ] This was checked for breaking HTTP API changes, and any breaking changes have been approved by the breaking-change committee. The `release_note:breaking` label should be applied in these situations. 
- [ ] [Flaky Test Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was used on any tests changed - [x] The PR description includes the appropriate Release Notes section, and the correct `release_note:*` label is applied per the [guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process) ## Release note Fixes issue in Search Playground where context limit errors were not handled well when using the Elastic Managed LLM. (cherry picked from commit 34dfd62) # Conflicts: # x-pack/platform/plugins/private/translations/translations/fr-FR.json # x-pack/platform/plugins/private/translations/translations/ja-JP.json # x-pack/platform/plugins/private/translations/translations/zh-CN.json # x-pack/solutions/search/plugins/search_playground/public/hooks/use_llms_models.ts # x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.test.ts # x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.ts # x-pack/solutions/search/plugins/search_playground/server/routes.ts
1 parent c0a42a2 commit 0399d3c

11 files changed

Lines changed: 140 additions & 25 deletions

File tree

x-pack/platform/plugins/private/translations/translations/fr-FR.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34850,8 +34850,7 @@
3485034850
"xpack.searchPlayground.geminiConnectorTitle": "Gemini",
3485134851
"xpack.searchPlayground.header.view.chat": "Chat",
3485234852
"xpack.searchPlayground.header.view.preview": "Aperçu",
34853-
"xpack.searchPlayground.header.view.query": "Recherche",
34854-
"xpack.searchPlayground.inferenceModel": "{name} (connecteur IA)",
34853+
"xpack.searchPlayground.header.view.query": "Requête",
3485534854
"xpack.searchPlayground.loadConnectorsError": "Erreur lors du chargement des connecteurs. Veuillez vérifier votre configuration et réessayer.",
3485634855
"xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure",
3485734856
"xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)",

x-pack/platform/plugins/private/translations/translations/ja-JP.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34826,7 +34826,6 @@
3482634826
"xpack.searchPlayground.header.view.chat": "チャット",
3482734827
"xpack.searchPlayground.header.view.preview": "プレビュー",
3482834828
"xpack.searchPlayground.header.view.query": "クエリー",
34829-
"xpack.searchPlayground.inferenceModel": "{name}(AIコネクター)",
3483034829
"xpack.searchPlayground.loadConnectorsError": "コネクターの読み込みエラーです。構成を確認して、再試行してください。",
3483134830
"xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure",
3483234831
"xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)",

x-pack/platform/plugins/private/translations/translations/zh-CN.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34885,7 +34885,6 @@
3488534885
"xpack.searchPlayground.header.view.chat": "聊天",
3488634886
"xpack.searchPlayground.header.view.preview": "预览",
3488734887
"xpack.searchPlayground.header.view.query": "查询",
34888-
"xpack.searchPlayground.inferenceModel": "{name}(AI 连接器)",
3488934888
"xpack.searchPlayground.loadConnectorsError": "加载连接器进出错。请检查您的配置,然后重试。",
3489034889
"xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure",
3489134890
"xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)",

x-pack/solutions/search/plugins/search_playground/common/models.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* 2.0.
66
*/
77

8+
import { elasticModelIds } from '@kbn/inference-common';
89
import { ModelProvider, LLMs } from './types';
910

1011
export const MODELS: ModelProvider[] = [
@@ -50,4 +51,10 @@ export const MODELS: ModelProvider[] = [
5051
promptTokenLimit: 2097152,
5152
provider: LLMs.gemini,
5253
},
54+
{
55+
name: 'Elastic Managed LLM',
56+
model: elasticModelIds.RainbowSprinkles,
57+
promptTokenLimit: 200000,
58+
provider: LLMs.inference,
59+
},
5360
];

x-pack/solutions/search/plugins/search_playground/public/hooks/use_llms_models.ts

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ const mapLlmToModels: Record<
2626
icon: string | ((connector: PlaygroundConnector) => string);
2727
getModels: (
2828
connectorName: string,
29-
includeName: boolean
29+
includeName: boolean,
30+
modelId?: string
3031
) => Array<{ label: string; value?: string; promptTokenLimit?: number }>;
3132
}
3233
> = {
@@ -85,12 +86,11 @@ const mapLlmToModels: Record<
8586
? SERVICE_PROVIDERS[connector.config.provider].icon
8687
: '';
8788
},
88-
getModels: (connectorName) => [
89+
getModels: (connectorName, _, modelId) => [
8990
{
90-
label: i18n.translate('xpack.searchPlayground.inferenceModel', {
91-
defaultMessage: '{name}',
92-
values: { name: connectorName },
93-
}),
91+
label: connectorName,
92+
value: modelId,
93+
promptTokenLimit: MODELS.find((m) => m.model === modelId)?.promptTokenLimit,
9494
},
9595
],
9696
},
@@ -126,7 +126,13 @@ export const useLLMsModels = (): LLMModel[] => {
126126
return [
127127
...result,
128128
...llmParams
129-
.getModels(connector.name, false)
129+
.getModels(
130+
connector.name,
131+
false,
132+
isInferenceActionConnector(connector)
133+
? connector.config?.providerConfig?.model_id
134+
: undefined
135+
)
130136
.map(({ label, value, promptTokenLimit }) => ({
131137
id: connector?.id + label,
132138
name: label,

x-pack/solutions/search/plugins/search_playground/public/providers/form_provider.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ export const FormProvider: React.FC<React.PropsWithChildren<FormProviderProps>>
8484
}, [form, storage]);
8585

8686
useEffect(() => {
87+
if (models.length === 0) return; // don't continue if there are no models
8788
const defaultModel = models.find((model) => !model.disabled);
8889
const currentModel = form.getValues(ChatFormFields.summarizationModel);
8990

x-pack/solutions/search/plugins/search_playground/public/types.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,12 @@ export interface LLMModel {
231231

232232
export type { ActionConnector, UserConfiguredActionConnector };
233233
export type InferenceActionConnector = ActionConnector & {
234-
config: { provider: ServiceProviderKeys; inferenceId: string };
234+
config: {
235+
providerConfig?: {
236+
model_id?: string;
237+
};
238+
provider: ServiceProviderKeys;
239+
inferenceId: string;
240+
};
235241
};
236242
export type PlaygroundConnector = ActionConnector & { title: string; type: LLMs };

x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.test.ts

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
import { Prompt, QuestionRewritePrompt } from '../../common/prompt';
2121
import { KibanaRequest, Logger } from '@kbn/core/server';
2222
import { PluginStartContract as ActionsPluginStartContract } from '@kbn/actions-plugin/server';
23+
import { elasticModelIds } from '@kbn/inference-common';
2324

2425
jest.mock('@kbn/langchain/server', () => {
2526
const original = jest.requireActual('@kbn/langchain/server');
@@ -230,4 +231,67 @@ describe('getChatParams', () => {
230231
});
231232
expect(result.chatPrompt).toContain('How does it work?');
232233
});
234+
235+
it('returns the correct params for the EIS connector', async () => {
236+
const mockConnector = {
237+
id: 'elastic-llm',
238+
actionTypeId: INFERENCE_CONNECTOR_ID,
239+
config: {
240+
providerConfig: {
241+
model_id: elasticModelIds.RainbowSprinkles,
242+
},
243+
},
244+
};
245+
mockActionsClient.get.mockResolvedValue(mockConnector);
246+
247+
const result = await getChatParams(
248+
{
249+
connectorId: 'elastic-llm',
250+
prompt: 'How does it work?',
251+
citations: false,
252+
},
253+
{ actions, request, logger }
254+
);
255+
256+
expect(result).toMatchObject({
257+
connector: mockConnector,
258+
summarizationModel: elasticModelIds.RainbowSprinkles,
259+
});
260+
261+
expect(Prompt).toHaveBeenCalledWith('How does it work?', {
262+
citations: false,
263+
context: true,
264+
type: 'openai',
265+
});
266+
expect(QuestionRewritePrompt).toHaveBeenCalledWith({
267+
type: 'openai',
268+
});
269+
});
270+
271+
it('it returns provided model with EIS connector', async () => {
272+
const mockConnector = {
273+
id: 'elastic-llm',
274+
actionTypeId: INFERENCE_CONNECTOR_ID,
275+
config: {
276+
providerConfig: {
277+
model_id: elasticModelIds.RainbowSprinkles,
278+
},
279+
},
280+
};
281+
mockActionsClient.get.mockResolvedValue(mockConnector);
282+
283+
const result = await getChatParams(
284+
{
285+
connectorId: 'elastic-llm',
286+
model: 'foo-bar',
287+
prompt: 'How does it work?',
288+
citations: false,
289+
},
290+
{ actions, request, logger }
291+
);
292+
293+
expect(result).toMatchObject({
294+
summarizationModel: 'foo-bar',
295+
});
296+
});
233297
});

x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.ts

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import {
2121
import { GEMINI_CONNECTOR_ID } from '@kbn/stack-connectors-plugin/common/gemini/constants';
2222
import { INFERENCE_CONNECTOR_ID } from '@kbn/stack-connectors-plugin/common/inference/constants';
2323
import { Prompt, QuestionRewritePrompt } from '../../common/prompt';
24+
import { isEISConnector } from '../utils/eis';
2425

2526
export const getChatParams = async (
2627
{
@@ -43,9 +44,11 @@ export const getChatParams = async (
4344
chatPrompt: string;
4445
questionRewritePrompt: string;
4546
connector: Connector;
47+
summarizationModel?: string;
4648
}> => {
4749
const abortController = new AbortController();
4850
const abortSignal = abortController.signal;
51+
let summarizationModel = model;
4952
const actionsClient = await actions.getActionsClientWithRequest(request);
5053
const connector = await actionsClient.get({ id: connectorId });
5154
let chatModel;
@@ -55,12 +58,17 @@ export const getChatParams = async (
5558

5659
switch (connector.actionTypeId) {
5760
case INFERENCE_CONNECTOR_ID:
61+
if (isEISConnector(connector)) {
62+
if (!summarizationModel && connector.config?.providerConfig?.model_id) {
63+
summarizationModel = connector.config?.providerConfig?.model_id;
64+
}
65+
}
5866
llmType = 'inference';
5967
chatModel = new ActionsClientChatOpenAI({
6068
actionsClient,
6169
logger,
6270
connectorId,
63-
model: connector?.config?.defaultModel,
71+
model: summarizationModel || connector?.config?.defaultModel,
6472
llmType,
6573
temperature: getDefaultArguments(llmType).temperature,
6674
// prevents the agent from retrying on failure
@@ -146,5 +154,11 @@ export const getChatParams = async (
146154
throw new Error('Invalid connector id');
147155
}
148156

149-
return { chatModel, chatPrompt, questionRewritePrompt, connector };
157+
return {
158+
chatModel,
159+
chatPrompt,
160+
questionRewritePrompt,
161+
connector,
162+
summarizationModel: summarizationModel || connector?.config?.defaultModel,
163+
};
150164
};

x-pack/solutions/search/plugins/search_playground/server/routes.ts

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,16 @@ export function defineRoutes({
115115
es_client: client.asCurrentUser,
116116
} as AssistClientOptionsWithClient);
117117
const { messages, data } = request.body;
118-
const { chatModel, chatPrompt, questionRewritePrompt, connector } = await getChatParams(
119-
{
120-
connectorId: data.connector_id,
121-
model: data.summarization_model,
122-
citations: data.citations,
123-
prompt: data.prompt,
124-
},
125-
{ actions, logger, request }
126-
);
118+
const { chatModel, chatPrompt, questionRewritePrompt, connector, summarizationModel } =
119+
await getChatParams(
120+
{
121+
connectorId: data.connector_id,
122+
model: data.summarization_model,
123+
citations: data.citations,
124+
prompt: data.prompt,
125+
},
126+
{ actions, logger, request }
127+
);
127128

128129
let sourceFields = {};
129130

@@ -139,7 +140,7 @@ export function defineRoutes({
139140
throw Error(e);
140141
}
141142

142-
const model = MODELS.find((m) => m.model === data.summarization_model);
143+
const model = MODELS.find((m) => m.model === summarizationModel);
143144
const modelPromptLimit = model?.promptTokenLimit;
144145

145146
const chain = ConversationalChain({
@@ -162,7 +163,7 @@ export function defineRoutes({
162163
connectorType:
163164
connector.actionTypeId +
164165
(connector.config?.apiProvider ? `-${connector.config.apiProvider}` : ''),
165-
model: data.summarization_model ?? '',
166+
model: summarizationModel ?? '',
166167
isCitationsEnabled: data.citations,
167168
});
168169

0 commit comments

Comments
 (0)