Skip to content

Commit 66b5892

Browse files
authored
Introduce streamMode for useChat / useCompletion. (#1350)
1 parent f272b01 commit 66b5892

23 files changed

+879
-455
lines changed

.changeset/empty-windows-think.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
Add streamMode parameter to useChat and useCompletion.

examples/solidstart-openai/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"vite": "^4.1.4"
1717
},
1818
"dependencies": {
19+
"@ai-sdk/openai": "latest",
1920
"@solidjs/meta": "0.29.3",
2021
"@solidjs/router": "0.8.2",
2122
"ai": "latest",
Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,19 @@
1-
import { OpenAIStream, StreamingTextResponse } from 'ai';
2-
import OpenAI from 'openai';
1+
import { openai } from '@ai-sdk/openai';
2+
import { StreamingTextResponse, experimental_streamText } from 'ai';
33
import { APIEvent } from 'solid-start/api';
44

5-
// Create an OpenAI API client
6-
const openai = new OpenAI({
7-
apiKey: process.env['OPENAI_API_KEY'] || '',
8-
});
9-
105
export const POST = async (event: APIEvent) => {
11-
// Extract the `prompt` from the body of the request
12-
const { messages } = await event.request.json();
6+
try {
7+
const { messages } = await event.request.json();
138

14-
// Ask OpenAI for a streaming chat completion given the prompt
15-
const response = await openai.chat.completions.create({
16-
model: 'gpt-3.5-turbo',
17-
stream: true,
18-
messages,
19-
});
9+
const result = await experimental_streamText({
10+
model: openai.chat('gpt-4-turbo-preview'),
11+
messages,
12+
});
2013

21-
// Convert the response into a friendly text-stream
22-
const stream = OpenAIStream(response);
23-
// Respond with the stream
24-
return new StreamingTextResponse(stream);
14+
return new StreamingTextResponse(result.toAIStream());
15+
} catch (error) {
16+
console.error(error);
17+
throw error;
18+
}
2519
};

packages/core/react/use-chat.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ const getStreamedResponse = async (
8888
messagesRef: React.MutableRefObject<Message[]>,
8989
abortControllerRef: React.MutableRefObject<AbortController | null>,
9090
generateId: IdGenerator,
91+
streamMode?: 'stream-data' | 'text',
9192
onFinish?: (message: Message) => void,
9293
onResponse?: (response: Response) => void | Promise<void>,
9394
sendExtraMessageFields?: boolean,
@@ -179,6 +180,7 @@ const getStreamedResponse = async (
179180
tool_choice: chatRequest.tool_choice,
180181
}),
181182
},
183+
streamMode,
182184
credentials: extraMetadataRef.current.credentials,
183185
headers: {
184186
...extraMetadataRef.current.headers,
@@ -206,6 +208,7 @@ export function useChat({
206208
sendExtraMessageFields,
207209
experimental_onFunctionCall,
208210
experimental_onToolCall,
211+
streamMode,
209212
onResponse,
210213
onFinish,
211214
onError,
@@ -292,6 +295,7 @@ export function useChat({
292295
messagesRef,
293296
abortControllerRef,
294297
generateId,
298+
streamMode,
295299
onFinish,
296300
onResponse,
297301
sendExtraMessageFields,

packages/core/react/use-chat.ui.test.tsx

Lines changed: 160 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -9,145 +9,201 @@ import {
99
} from '../tests/utils/mock-fetch';
1010
import { useChat } from './use-chat';
1111

12-
const TestComponent = () => {
13-
const [id, setId] = React.useState<string>('first-id');
14-
const { messages, append, error, data, isLoading } = useChat({ id });
15-
16-
return (
17-
<div>
18-
<div data-testid="loading">{isLoading.toString()}</div>
19-
{error && <div data-testid="error">{error.toString()}</div>}
20-
{data && <div data-testid="data">{JSON.stringify(data)}</div>}
21-
{messages.map((m, idx) => (
22-
<div data-testid={`message-${idx}`} key={m.id}>
23-
{m.role === 'user' ? 'User: ' : 'AI: '}
24-
{m.content}
25-
</div>
26-
))}
27-
28-
<button
29-
data-testid="do-append"
30-
onClick={() => {
31-
append({ role: 'user', content: 'hi' });
32-
}}
33-
/>
34-
<button
35-
data-testid="do-change-id"
36-
onClick={() => {
37-
setId('second-id');
38-
}}
39-
/>
40-
</div>
41-
);
42-
};
43-
44-
beforeEach(() => {
45-
render(<TestComponent />);
46-
});
12+
describe('stream data stream', () => {
13+
const TestComponent = () => {
14+
const [id, setId] = React.useState<string>('first-id');
15+
const { messages, append, error, data, isLoading } = useChat({ id });
16+
17+
return (
18+
<div>
19+
<div data-testid="loading">{isLoading.toString()}</div>
20+
{error && <div data-testid="error">{error.toString()}</div>}
21+
{data && <div data-testid="data">{JSON.stringify(data)}</div>}
22+
{messages.map((m, idx) => (
23+
<div data-testid={`message-${idx}`} key={m.id}>
24+
{m.role === 'user' ? 'User: ' : 'AI: '}
25+
{m.content}
26+
</div>
27+
))}
28+
29+
<button
30+
data-testid="do-append"
31+
onClick={() => {
32+
append({ role: 'user', content: 'hi' });
33+
}}
34+
/>
35+
<button
36+
data-testid="do-change-id"
37+
onClick={() => {
38+
setId('second-id');
39+
}}
40+
/>
41+
</div>
42+
);
43+
};
4744

48-
afterEach(() => {
49-
vi.restoreAllMocks();
50-
cleanup();
51-
});
45+
beforeEach(() => {
46+
render(<TestComponent />);
47+
});
5248

53-
test('Shows streamed complex text response', async () => {
54-
mockFetchDataStream({
55-
url: 'https://example.com/api/chat',
56-
chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
49+
afterEach(() => {
50+
vi.restoreAllMocks();
51+
cleanup();
5752
});
5853

59-
await userEvent.click(screen.getByTestId('do-append'));
54+
it('should show streamed response', async () => {
55+
mockFetchDataStream({
56+
url: 'https://example.com/api/chat',
57+
chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
58+
});
6059

61-
await screen.findByTestId('message-0');
62-
expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
60+
await userEvent.click(screen.getByTestId('do-append'));
6361

64-
await screen.findByTestId('message-1');
65-
expect(screen.getByTestId('message-1')).toHaveTextContent(
66-
'AI: Hello, world.',
67-
);
68-
});
62+
await screen.findByTestId('message-0');
63+
expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
6964

70-
test('Shows streamed complex text response with data', async () => {
71-
mockFetchDataStream({
72-
url: 'https://example.com/api/chat',
73-
chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'],
65+
await screen.findByTestId('message-1');
66+
expect(screen.getByTestId('message-1')).toHaveTextContent(
67+
'AI: Hello, world.',
68+
);
7469
});
7570

76-
await userEvent.click(screen.getByTestId('do-append'));
71+
it('should show streamed response with data', async () => {
72+
mockFetchDataStream({
73+
url: 'https://example.com/api/chat',
74+
chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'],
75+
});
7776

78-
await screen.findByTestId('data');
79-
expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]');
77+
await userEvent.click(screen.getByTestId('do-append'));
8078

81-
await screen.findByTestId('message-1');
82-
expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello');
83-
});
79+
await screen.findByTestId('data');
80+
expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]');
8481

85-
test('Shows error response', async () => {
86-
mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
82+
await screen.findByTestId('message-1');
83+
expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello');
84+
});
8785

88-
await userEvent.click(screen.getByTestId('do-append'));
86+
it('should show error response', async () => {
87+
mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
8988

90-
// TODO bug? the user message does not show up
91-
// await screen.findByTestId('message-0');
92-
// expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
89+
await userEvent.click(screen.getByTestId('do-append'));
9390

94-
await screen.findByTestId('error');
95-
expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found');
96-
});
91+
// TODO bug? the user message does not show up
92+
// await screen.findByTestId('message-0');
93+
// expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
94+
95+
await screen.findByTestId('error');
96+
expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found');
97+
});
98+
99+
describe('loading state', () => {
100+
it('should show loading state', async () => {
101+
let finishGeneration: ((value?: unknown) => void) | undefined;
102+
const finishGenerationPromise = new Promise(resolve => {
103+
finishGeneration = resolve;
104+
});
97105

98-
describe('loading state', () => {
99-
test('should show loading state', async () => {
100-
let finishGeneration: ((value?: unknown) => void) | undefined;
101-
const finishGenerationPromise = new Promise(resolve => {
102-
finishGeneration = resolve;
106+
mockFetchDataStreamWithGenerator({
107+
url: 'https://example.com/api/chat',
108+
chunkGenerator: (async function* generate() {
109+
const encoder = new TextEncoder();
110+
yield encoder.encode('0:"Hello"\n');
111+
await finishGenerationPromise;
112+
})(),
113+
});
114+
115+
await userEvent.click(screen.getByTestId('do-append'));
116+
117+
await screen.findByTestId('loading');
118+
expect(screen.getByTestId('loading')).toHaveTextContent('true');
119+
120+
finishGeneration?.();
121+
122+
await findByText(await screen.findByTestId('loading'), 'false');
123+
expect(screen.getByTestId('loading')).toHaveTextContent('false');
103124
});
104125

105-
mockFetchDataStreamWithGenerator({
106-
url: 'https://example.com/api/chat',
107-
chunkGenerator: (async function* generate() {
108-
const encoder = new TextEncoder();
109-
yield encoder.encode('0:"Hello"\n');
110-
await finishGenerationPromise;
111-
})(),
126+
it('should reset loading state on error', async () => {
127+
mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
128+
129+
await userEvent.click(screen.getByTestId('do-append'));
130+
131+
await screen.findByTestId('loading');
132+
expect(screen.getByTestId('loading')).toHaveTextContent('false');
112133
});
134+
});
113135

114-
await userEvent.click(screen.getByTestId('do-append'));
136+
describe('id', () => {
137+
it('should clear out messages when the id changes', async () => {
138+
mockFetchDataStream({
139+
url: 'https://example.com/api/chat',
140+
chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
141+
});
115142

116-
await screen.findByTestId('loading');
117-
expect(screen.getByTestId('loading')).toHaveTextContent('true');
143+
await userEvent.click(screen.getByTestId('do-append'));
118144

119-
finishGeneration?.();
145+
await screen.findByTestId('message-1');
146+
expect(screen.getByTestId('message-1')).toHaveTextContent(
147+
'AI: Hello, world.',
148+
);
120149

121-
await findByText(await screen.findByTestId('loading'), 'false');
122-
expect(screen.getByTestId('loading')).toHaveTextContent('false');
150+
await userEvent.click(screen.getByTestId('do-change-id'));
151+
152+
expect(screen.queryByTestId('message-0')).not.toBeInTheDocument();
153+
});
123154
});
155+
});
124156

125-
test('should reset loading state on error', async () => {
126-
mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
157+
describe('text stream', () => {
158+
const TestComponent = () => {
159+
const { messages, append } = useChat({
160+
streamMode: 'text',
161+
});
127162

128-
await userEvent.click(screen.getByTestId('do-append'));
163+
return (
164+
<div>
165+
{messages.map((m, idx) => (
166+
<div data-testid={`message-${idx}-text-stream`} key={m.id}>
167+
{m.role === 'user' ? 'User: ' : 'AI: '}
168+
{m.content}
169+
</div>
170+
))}
171+
172+
<button
173+
data-testid="do-append-text-stream"
174+
onClick={() => {
175+
append({ role: 'user', content: 'hi' });
176+
}}
177+
/>
178+
</div>
179+
);
180+
};
129181

130-
await screen.findByTestId('loading');
131-
expect(screen.getByTestId('loading')).toHaveTextContent('false');
182+
beforeEach(() => {
183+
render(<TestComponent />);
132184
});
133-
});
134185

135-
describe('id', () => {
136-
it('should clear out messages when the id changes', async () => {
186+
afterEach(() => {
187+
vi.restoreAllMocks();
188+
cleanup();
189+
});
190+
191+
it('should show streamed response', async () => {
137192
mockFetchDataStream({
138193
url: 'https://example.com/api/chat',
139-
chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
194+
chunks: ['Hello', ',', ' world', '.'],
140195
});
141196

142-
await userEvent.click(screen.getByTestId('do-append'));
197+
await userEvent.click(screen.getByTestId('do-append-text-stream'));
143198

144-
await screen.findByTestId('message-1');
145-
expect(screen.getByTestId('message-1')).toHaveTextContent(
146-
'AI: Hello, world.',
199+
await screen.findByTestId('message-0-text-stream');
200+
expect(screen.getByTestId('message-0-text-stream')).toHaveTextContent(
201+
'User: hi',
147202
);
148203

149-
await userEvent.click(screen.getByTestId('do-change-id'));
150-
151-
expect(screen.queryByTestId('message-0')).not.toBeInTheDocument();
204+
await screen.findByTestId('message-1-text-stream');
205+
expect(screen.getByTestId('message-1-text-stream')).toHaveTextContent(
206+
'AI: Hello, world.',
207+
);
152208
});
153209
});

0 commit comments

Comments
 (0)