Skip to content

Commit 5e5ce0f

Browse files
committed
refactor: remove connect from providers
1 parent 52532c8 commit 5e5ce0f

File tree

7 files changed

+93
-90
lines changed

7 files changed

+93
-90
lines changed

mobius/internal/api/v1/chatcompletions.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package v1
22

33
import (
44
"context"
5+
"encoding/json"
56

67
"connectrpc.com/connect"
78
"github.com/missingstudio/studio/backend/internal/providers"
@@ -28,8 +29,18 @@ func (s *V1Handler) ChatCompletions(
2829
return nil, errors.NewInternalError("provider don't have chat Completion capabilities")
2930
}
3031

31-
data, err := chatCompletionProvider.ChatCompletion(ctx, req.Msg)
32+
payload, err := json.Marshal(req.Msg)
3233
if err != nil {
34+
return nil, err
35+
}
36+
37+
resp, err := chatCompletionProvider.ChatCompletion(ctx, payload)
38+
if err != nil {
39+
return nil, errors.New(err)
40+
}
41+
42+
data := &llmv1.CompletionResponse{}
43+
if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
3344
return nil, errors.New(err)
3445
}
3546

mobius/internal/providers/anyscale/anyscale.go

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,19 @@ package anyscale
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"fmt"
87
"net/http"
98

109
"github.com/missingstudio/studio/backend/pkg/requester"
11-
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
1210
)
1311

14-
func (anyscale *AnyscaleProvider) ChatCompletion(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
15-
payload, err := json.Marshal(cr)
16-
if err != nil {
17-
return nil, err
18-
}
19-
12+
func (anyscale *AnyscaleProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
2013
client := requester.NewHTTPClient()
2114
requestURL := fmt.Sprintf("%s%s", anyscale.Config.BaseURL, anyscale.Config.ChatCompletions)
2215
req, _ := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
2316

2417
req.Header.Add("Content-Type", "application/json")
2518
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", anyscale.APIKey))
2619

27-
resp, err := client.Do(req)
28-
if err != nil {
29-
return nil, err
30-
}
31-
32-
var data llmv1.CompletionResponse
33-
err = json.Unmarshal(resp, &data)
34-
if err != nil {
35-
return nil, err
36-
}
37-
38-
return &data, nil
20+
return client.SendRequestRaw(req)
3921
}

mobius/internal/providers/base/base.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ package base
22

33
import (
44
"context"
5-
6-
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
5+
"net/http"
76
)
87

98
type ProviderConfig struct {
@@ -18,5 +17,5 @@ type ProviderInterface interface {
1817

1918
type ChatCompletionInterface interface {
2019
ProviderInterface
21-
ChatCompletion(context.Context, *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error)
20+
ChatCompletion(context.Context, []byte) (*http.Response, error)
2221
}

mobius/internal/providers/deepinfra/deepinfra.go

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,19 @@ package deepinfra
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"fmt"
87
"net/http"
98

109
"github.com/missingstudio/studio/backend/pkg/requester"
11-
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
1210
)
1311

14-
func (deepinfra *DeepinfraProvider) ChatCompletion(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
15-
payload, err := json.Marshal(cr)
16-
if err != nil {
17-
return nil, err
18-
}
19-
12+
func (deepinfra *DeepinfraProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
2013
client := requester.NewHTTPClient()
2114
requestURL := fmt.Sprintf("%s%s", deepinfra.Config.BaseURL, deepinfra.Config.ChatCompletions)
2215
req, _ := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
2316

2417
req.Header.Add("Content-Type", "application/json")
2518
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", deepinfra.APIKey))
2619

27-
resp, err := client.Do(req)
28-
if err != nil {
29-
return nil, err
30-
}
31-
32-
var data llmv1.CompletionResponse
33-
err = json.Unmarshal(resp, &data)
34-
if err != nil {
35-
return nil, err
36-
}
37-
38-
return &data, nil
20+
return client.SendRequestRaw(req)
3921
}

mobius/internal/providers/openai/openai.go

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,19 @@ package openai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"fmt"
87
"net/http"
98

109
"github.com/missingstudio/studio/backend/pkg/requester"
11-
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
1210
)
1311

14-
func (oai *OpenAIProvider) ChatCompletion(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
15-
payload, err := json.Marshal(cr)
16-
if err != nil {
17-
return nil, err
18-
}
19-
12+
func (oai *OpenAIProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
2013
client := requester.NewHTTPClient()
2114
requestURL := fmt.Sprintf("%s%s", oai.Config.BaseURL, oai.Config.ChatCompletions)
2215
req, _ := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
2316

2417
req.Header.Add("Content-Type", "application/json")
2518
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", oai.APIKey))
2619

27-
resp, err := client.Do(req)
28-
if err != nil {
29-
return nil, err
30-
}
31-
32-
var data llmv1.CompletionResponse
33-
err = json.Unmarshal(resp, &data)
34-
if err != nil {
35-
return nil, err
36-
}
37-
38-
return &data, nil
20+
return client.SendRequestRaw(req)
3921
}

mobius/internal/providers/togetherai/togetherai.go

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,19 @@ package togetherai
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
76
"fmt"
87
"net/http"
98

109
"github.com/missingstudio/studio/backend/pkg/requester"
11-
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
1210
)
1311

14-
func (ta *TogetherAIProvider) ChatCompletion(ctx context.Context, cr *llmv1.CompletionRequest) (*llmv1.CompletionResponse, error) {
15-
payload, err := json.Marshal(cr)
16-
if err != nil {
17-
return nil, err
18-
}
19-
12+
func (ta *TogetherAIProvider) ChatCompletion(ctx context.Context, payload []byte) (*http.Response, error) {
2013
client := requester.NewHTTPClient()
2114
requestURL := fmt.Sprintf("%s%s", ta.Config.BaseURL, ta.Config.ChatCompletions)
2215
req, _ := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(payload))
2316

2417
req.Header.Add("Content-Type", "application/json")
2518
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", ta.APIKey))
2619

27-
resp, err := client.Do(req)
28-
if err != nil {
29-
return nil, err
30-
}
31-
32-
var data llmv1.CompletionResponse
33-
err = json.Unmarshal(resp, &data)
34-
if err != nil {
35-
return nil, err
36-
}
37-
38-
return &data, nil
20+
return client.SendRequestRaw(req)
3921
}

mobius/pkg/requester/http.go

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package requester
22

33
import (
4+
"bytes"
5+
"encoding/json"
46
"fmt"
57
"io"
68
"net/http"
@@ -36,7 +38,7 @@ func NewHTTPClient(opts ...HTTPOption) *HTTPClient {
3638
}
3739
}
3840

39-
func (c *HTTPClient) Do(req *http.Request) ([]byte, error) {
41+
func (c *HTTPClient) SendRequest(req *http.Request, response any, outputResponse bool) (*http.Response, error) {
4042
if err := c.before(req); err != nil {
4143
return nil, err
4244
}
@@ -46,14 +48,77 @@ func (c *HTTPClient) Do(req *http.Request) ([]byte, error) {
4648
return nil, err
4749
}
4850

49-
defer resp.Body.Close()
50-
body, err := io.ReadAll(resp.Body)
51+
if !outputResponse {
52+
defer resp.Body.Close()
53+
}
54+
55+
if c.IsFailureStatusCode(resp) {
56+
return nil, fmt.Errorf("http status not 2xx: %d", resp.StatusCode)
57+
}
58+
59+
if outputResponse {
60+
var buf bytes.Buffer
61+
tee := io.TeeReader(resp.Body, &buf)
62+
err = DecodeResponse(tee, response)
63+
64+
resp.Body = io.NopCloser(&buf)
65+
} else {
66+
err = json.NewDecoder(resp.Body).Decode(response)
67+
}
68+
69+
if err != nil {
70+
return nil, err
71+
}
72+
73+
return resp, nil
74+
}
75+
76+
func (c *HTTPClient) SendRequestRaw(req *http.Request) (*http.Response, error) {
77+
if err := c.before(req); err != nil {
78+
return nil, err
79+
}
80+
81+
resp, err := c.client.Do(req)
5182
if err != nil {
5283
return nil, err
5384
}
5485

55-
if int(resp.StatusCode/100) != 2 {
56-
return nil, fmt.Errorf("http status not 2xx: %d %s", resp.StatusCode, string(body))
86+
if c.IsFailureStatusCode(resp) {
87+
return nil, fmt.Errorf("http status not 2xx: %d", resp.StatusCode)
88+
}
89+
90+
return resp, nil
91+
}
92+
93+
func (r *HTTPClient) IsFailureStatusCode(resp *http.Response) bool {
94+
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
95+
}
96+
97+
type Stringer interface {
98+
GetString() *string
99+
}
100+
101+
func DecodeResponse(body io.Reader, v any) error {
102+
if v == nil {
103+
return nil
104+
}
105+
106+
if result, ok := v.(*string); ok {
107+
return DecodeString(body, result)
108+
}
109+
110+
if stringer, ok := v.(Stringer); ok {
111+
return DecodeString(body, stringer.GetString())
112+
}
113+
114+
return json.NewDecoder(body).Decode(v)
115+
}
116+
117+
func DecodeString(body io.Reader, output *string) error {
118+
b, err := io.ReadAll(body)
119+
if err != nil {
120+
return err
57121
}
58-
return body, nil
122+
*output = string(b)
123+
return nil
59124
}

0 commit comments

Comments
 (0)