Skip to content

Commit 21d0642

Browse files
committed
feat: add requester pkg
1 parent aa0a6e8 commit 21d0642

8 files changed

Lines changed: 156 additions & 45 deletions

File tree

mobius/internal/connectrpc/mux.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package connectrpc
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"log"
78
"net/http"
@@ -12,6 +13,7 @@ import (
1213
"connectrpc.com/validate"
1314
"connectrpc.com/vanguard"
1415
"github.com/missingstudio/studio/backend/internal/providers"
16+
"github.com/missingstudio/studio/backend/internal/providers/base"
1517
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
1618
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
1719
)
@@ -70,7 +72,12 @@ func (s *LLMServer) ChatCompletions(
7072
return nil, connect.NewError(connect.CodeInternal, err)
7173
}
7274

73-
data, err := provider.ChatCompilation(ctx, req.Msg)
75+
completionProvider, ok := provider.(base.ChatCompilationInterface)
76+
if !ok {
77+
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("method not implemented"))
78+
}
79+
80+
data, err := completionProvider.ChatCompilation(ctx, req.Msg)
7481
if err != nil {
7582
return nil, connect.NewError(connect.CodeInternal, err)
7683
}

mobius/internal/providers/base/base.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ type ProviderConfig struct {
1111
ChatCompletions string
1212
}
1313

14-
type LLMProvider interface {
15-
ChatCompilation(ctx context.Context, ra *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error)
14+
type ProviderInterface interface{}
15+
16+
type ChatCompilationInterface interface {
17+
ProviderInterface
18+
ChatCompilation(context.Context, *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error)
1619
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package openai
2+
3+
import "github.com/missingstudio/studio/backend/internal/providers/base"
4+
5+
type OpenAIProvider struct {
6+
APIKey string
7+
Config base.ProviderConfig
8+
}
9+
10+
func NewOpenAIProvider(token string, baseURL string) *OpenAIProvider {
11+
config := getOpenAIConfig(baseURL)
12+
return &OpenAIProvider{
13+
APIKey: token,
14+
Config: config,
15+
}
16+
}
17+
18+
type OpenAIProviderFactory struct{}
19+
20+
func (f OpenAIProviderFactory) Create(token string) base.ProviderInterface {
21+
openAIProvider := NewOpenAIProvider(token, "https://api.openai.com")
22+
return openAIProvider
23+
}
24+
25+
func getOpenAIConfig(baseURL string) base.ProviderConfig {
26+
return base.ProviderConfig{
27+
BaseURL: baseURL,
28+
ChatCompletions: "/v1/chat/completions",
29+
}
30+
}

mobius/internal/providers/openai/openai.go

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,47 +4,19 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7-
"io"
87
"net/http"
98

10-
"github.com/missingstudio/studio/backend/internal/providers/base"
9+
"github.com/missingstudio/studio/backend/pkg/requester"
1110
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
1211
)
1312

14-
type OpenAIProviderFactory struct{}
15-
16-
func (f OpenAIProviderFactory) Create(token string) base.LLMProvider {
17-
openAIProvider := NewOpenAIProvider(token, "https://api.openai.com")
18-
return openAIProvider
19-
}
20-
21-
type OpenAIProvider struct {
22-
APIKey string
23-
Config base.ProviderConfig
24-
}
25-
26-
func NewOpenAIProvider(token string, baseURL string) *OpenAIProvider {
27-
config := getOpenAIConfig(baseURL)
28-
return &OpenAIProvider{
29-
APIKey: token,
30-
Config: config,
31-
}
32-
}
33-
34-
func getOpenAIConfig(baseURL string) base.ProviderConfig {
35-
return base.ProviderConfig{
36-
BaseURL: baseURL,
37-
ChatCompletions: "/v1/chat/completions",
38-
}
39-
}
40-
41-
func (oai OpenAIProvider) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
13+
func (oai *OpenAIProvider) ChatCompilation(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
4214
payload, err := json.Marshal(cr)
4315
if err != nil {
4416
return nil, err
4517
}
4618

47-
client := &http.Client{}
19+
client := requester.NewHTTPClient()
4820
req, _ := http.NewRequestWithContext(ctx, "POST", oai.Config.BaseURL+oai.Config.ChatCompletions, bytes.NewReader(payload))
4921
req.Header.Add("Content-Type", "application/json")
5022
req.Header.Add("Authorization", "Bearer "+oai.APIKey)
@@ -53,15 +25,9 @@ func (oai OpenAIProvider) ChatCompilation(ctx context.Context, cr *llmv1.Complet
5325
if err != nil {
5426
return nil, err
5527
}
56-
defer resp.Body.Close()
57-
58-
body, err := io.ReadAll(resp.Body)
59-
if err != nil {
60-
return nil, err
61-
}
6228

6329
var data llmv1.CompletionResponse
64-
err = json.Unmarshal(body, &data)
30+
err = json.Unmarshal(resp, &data)
6531
if err != nil {
6632
return nil, err
6733
}

mobius/internal/providers/providers.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
type ProviderFactory interface {
14-
Create(token string) base.LLMProvider
14+
Create(token string) base.ProviderInterface
1515
}
1616

1717
var providerFactories = make(map[string]ProviderFactory)
@@ -20,9 +20,9 @@ func init() {
2020
providerFactories["openai"] = openai.OpenAIProviderFactory{}
2121
}
2222

23-
func GetProvider(headers http.Header) (base.LLMProvider, error) {
24-
provider := headers.Get("x-ms-provider")
25-
providerFactory, ok := providerFactories[provider]
23+
func GetProvider(headers http.Header) (base.ProviderInterface, error) {
24+
providerType := headers.Get("x-ms-provider")
25+
providerFactory, ok := providerFactories[providerType]
2626
if !ok {
2727
return nil, connect.NewError(connect.CodeNotFound, errors.New("provider not found"))
2828
}

mobius/pkg/requester/encode.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package requester
2+
3+
import (
4+
"encoding/json"
5+
)
6+
7+
type (
8+
Encoder func(any) ([]byte, error)
9+
Decoder func([]byte, any) error
10+
)
11+
12+
var (
13+
defaultEncoder = json.Marshal
14+
defaultDecoder = json.Unmarshal
15+
)

mobius/pkg/requester/http.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package requester
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"time"
8+
)
9+
10+
const defaultHTTPTimeout = 5
11+
12+
// HTTPClient represents a client to send HTTP requests.
13+
type HTTPClient struct {
14+
client *http.Client
15+
16+
// encoder is used to encode request bodies
17+
encoder Encoder
18+
19+
// decoder is used to decode response bodies
20+
decoder Decoder
21+
22+
// before is a function called before each
23+
// request is made. useful for like, auth sigs, etc.
24+
before func(*http.Request) error
25+
}
26+
27+
// NewHTTPClient is used to build a new HTTPClient.
28+
func NewHTTPClient(opts ...HTTPOption) *HTTPClient {
29+
return &HTTPClient{
30+
encoder: defaultEncoder,
31+
decoder: defaultDecoder,
32+
client: &http.Client{
33+
Timeout: time.Second * time.Duration(defaultHTTPTimeout),
34+
},
35+
before: func(_ *http.Request) error { return nil },
36+
}
37+
}
38+
39+
func (c *HTTPClient) Do(req *http.Request) ([]byte, error) {
40+
if err := c.before(req); err != nil {
41+
return nil, err
42+
}
43+
44+
resp, err := c.client.Do(req)
45+
if err != nil {
46+
return nil, err
47+
}
48+
49+
defer resp.Body.Close()
50+
body, err := io.ReadAll(resp.Body)
51+
if err != nil {
52+
return nil, err
53+
}
54+
55+
if int(resp.StatusCode/100) != 2 {
56+
return nil, fmt.Errorf("http status not 2xx: %d %s", resp.StatusCode, string(body))
57+
}
58+
return body, nil
59+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package requester
2+
3+
import (
4+
"net/http"
5+
)
6+
7+
type HTTPOption func(*HTTPClient)
8+
9+
func WithClient(c *http.Client) HTTPOption {
10+
return func(client *HTTPClient) {
11+
client.client = c
12+
}
13+
}
14+
15+
func WithEncoder(fn func(obj any) ([]byte, error)) HTTPOption {
16+
return func(c *HTTPClient) {
17+
c.encoder = fn
18+
}
19+
}
20+
21+
func WithDecoder(fn func([]byte, any) error) HTTPOption {
22+
return func(c *HTTPClient) {
23+
c.decoder = fn
24+
}
25+
}
26+
27+
func WithBefore(fn func(*http.Request) error) HTTPOption {
28+
return func(c *HTTPClient) {
29+
c.before = fn
30+
}
31+
}

0 commit comments

Comments
 (0)