Skip to content

Commit 52532c8

Browse files
committed
feat: add connect stream package to create server stream
1 parent e30ed0b commit 52532c8

6 files changed

Lines changed: 240 additions & 14 deletions

File tree

mobius/go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ require (
1414
connectrpc.com/vanguard v0.1.0
1515
github.com/MakeNowJust/heredoc v1.0.0
1616
github.com/go-playground/validator/v10 v10.17.0
17+
github.com/google/go-cmp v0.6.0
1718
github.com/mcuadros/go-defaults v1.2.0
1819
github.com/missingstudio/studio/common v0.0.0-00010101000000-000000000000
1920
github.com/missingstudio/studio/protos v0.0.0-00010101000000-000000000000
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package mock
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"sync/atomic"
8+
"testing"
9+
)
10+
11+
type MockStream[T any] struct {
12+
mu sync.Mutex
13+
ctx context.Context
14+
ch chan *T
15+
closed bool
16+
counter *uint32
17+
Messages []*T
18+
MessageMap map[string]int
19+
}
20+
21+
func NewMockStream[T any](t *testing.T) *MockStream[T] {
22+
t.Helper()
23+
var counter uint32
24+
return &MockStream[T]{
25+
ctx: context.Background(),
26+
ch: make(chan *T),
27+
closed: false,
28+
counter: &counter,
29+
Messages: make([]*T, 0),
30+
MessageMap: make(map[string]int),
31+
}
32+
}
33+
34+
func (m *MockStream[T]) Run() error {
35+
for {
36+
select {
37+
case data, ok := <-m.ch:
38+
if !ok {
39+
return fmt.Errorf("stream closed")
40+
}
41+
atomic.AddUint32(m.counter, 1)
42+
m.Messages = append(m.Messages, data)
43+
case <-m.ctx.Done():
44+
return m.ctx.Err()
45+
}
46+
}
47+
}
48+
49+
func (m *MockStream[T]) GetChannel() chan *T {
50+
return m.ch
51+
}
52+
53+
func (m *MockStream[T]) Send(data *T) {
54+
m.mu.Lock()
55+
defer m.mu.Unlock()
56+
if !m.closed {
57+
m.ch <- data
58+
}
59+
}
60+
61+
func (m *MockStream[T]) Close() {
62+
m.mu.Lock()
63+
defer m.mu.Unlock()
64+
if !m.closed {
65+
close(m.ch)
66+
}
67+
m.closed = true
68+
}
69+
70+
// SetCounter sets the counter used to count the number of messages sent.
71+
// Multiple streams can share the same counter to count the total number of
72+
// messages sent across all streams.
73+
func (m *MockStream[T]) SetCounter(counter *uint32) {
74+
m.counter = counter
75+
}

mobius/internal/stream/stream.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package stream
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
8+
"connectrpc.com/connect"
9+
)
10+
11+
type StreamInterface[T any] interface {
12+
Send(data *T)
13+
Run() error
14+
Close()
15+
}
16+
17+
// Stream wraps a connect.ServerStream.
18+
type Stream[T any] struct {
19+
mu sync.Mutex
20+
// stream is the underlying connect stream
21+
// that does the actual transfer of data
22+
// between the server and a client
23+
stream *connect.ServerStream[T]
24+
// context is the context of the stream
25+
ctx context.Context
26+
// The channel that we listen to for any
27+
// new data that we need to send to the client.
28+
ch chan *T
29+
// closed is a flag that indicates whether
30+
// the stream has been closed.
31+
closed bool
32+
}
33+
34+
// newStream creates a new stream.
35+
func NewStream[T any](ctx context.Context, st *connect.ServerStream[T]) *Stream[T] {
36+
return &Stream[T]{
37+
stream: st,
38+
ctx: ctx,
39+
ch: make(chan *T),
40+
}
41+
}
42+
43+
// Close closes the stream.
44+
func (s *Stream[T]) Close() {
45+
s.mu.Lock()
46+
defer s.mu.Unlock()
47+
if !s.closed {
48+
close(s.ch)
49+
}
50+
s.closed = true
51+
}
52+
53+
// Run runs the stream.
54+
// Run will block until the stream is closed.
55+
func (s *Stream[T]) Run() error {
56+
defer s.Close()
57+
for {
58+
select {
59+
case <-s.ctx.Done():
60+
return s.ctx.Err()
61+
case data, ok := <-s.ch:
62+
if !ok {
63+
return connect.NewError(connect.CodeCanceled, fmt.Errorf("stream closed"))
64+
}
65+
if err := s.stream.Send(data); err != nil {
66+
return err
67+
}
68+
}
69+
}
70+
}
71+
72+
// Send sends data to this stream's connected client.
73+
func (s *Stream[T]) Send(data *T) {
74+
s.mu.Lock()
75+
defer s.mu.Unlock()
76+
if !s.closed {
77+
s.ch <- data
78+
}
79+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package stream
2+
3+
import (
4+
"fmt"
5+
"sync"
6+
"testing"
7+
8+
"github.com/google/go-cmp/cmp"
9+
"github.com/missingstudio/studio/backend/internal/mock"
10+
)
11+
12+
type Data struct {
13+
Msg string
14+
}
15+
16+
var messages = []*Data{
17+
{Msg: "Hello"},
18+
{Msg: "World"},
19+
{Msg: "Foo"},
20+
{Msg: "Bar"},
21+
{Msg: "Gandalf"},
22+
{Msg: "Frodo"},
23+
{Msg: "Bilbo"},
24+
{Msg: "Radagast"},
25+
{Msg: "Sauron"},
26+
{Msg: "Gollum"},
27+
}
28+
29+
func TestStream(t *testing.T) {
30+
var counter uint32
31+
stream := mock.NewMockStream[Data](t)
32+
stream.SetCounter(&counter)
33+
wg := sync.WaitGroup{}
34+
wg.Add(1)
35+
36+
go func() {
37+
defer wg.Done()
38+
err := stream.Run()
39+
t.Log(err)
40+
}()
41+
42+
for _, data := range messages {
43+
stream.Send(data)
44+
}
45+
46+
stream.Close()
47+
wg.Wait()
48+
49+
// A total of 10 messages should have been sent.
50+
if counter != 10 {
51+
fmt.Println(counter)
52+
t.Errorf("expected 10, got %d", counter)
53+
}
54+
55+
msgMsp := make(map[string]int)
56+
for _, data := range stream.Messages {
57+
msgMsp[data.Msg]++
58+
}
59+
60+
if len(stream.Messages) != 10 {
61+
t.Errorf("expected 10 messages, got %d", len(stream.Messages))
62+
}
63+
if diff := cmp.Diff(messages, stream.Messages); diff != "" {
64+
t.Errorf("expected %v, got %v: %s", messages, stream.Messages, diff)
65+
}
66+
}

protos/pkg/llm/service.pb.go

Lines changed: 15 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

protos/proto/llm/service.proto

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ service LLMService {
1010
option (google.api.http).post = "/v1/chat/completions";
1111
option (google.api.http).body = "*";
1212
}
13-
rpc StreamChatCompletions(CompletionRequest) returns (stream CompletionResponse) {}
13+
rpc StreamChatCompletions(CompletionRequest) returns (stream CompletionResponse) {
14+
option (google.api.http).post = "/v1/chat/completions:stream";
15+
option (google.api.http).body = "*";
16+
}
1417
}
1518

1619
enum FinishReason {

0 commit comments

Comments
 (0)