Skip to content

Commit ef2b619

Browse files
committed
refactor: contants and errors
1 parent 29b04b2 commit ef2b619

File tree

10 files changed

+106
-59
lines changed

10 files changed

+106
-59
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package constants
2+
3+
const (
4+
MIMEApplicationJSON = "application/json"
5+
MIMEApplicationProtobuf = "application/protobuf"
6+
MIMEOctetStream = "application/octet-stream"
7+
MIMEApplicationForm = "application/x-www-form-urlencoded"
8+
MIMEMultipartForm = "multipart/form-data"
9+
MIMETextEventStream = "text/event-stream"
10+
MIMETextPlain = "text/plain"
11+
MIMETextHTML = "text/html"
12+
MIMEKeepAlive = "keep-alive"
13+
)
14+
15+
const (
16+
HeaderAuthorization = "Authorization"
17+
HeaderCacheControl = "Cache-Control"
18+
)
19+
20+
const (
21+
XMSAPIKey = "X-MS-Api-Key"
22+
XMSProvider = "X-MS-Provider"
23+
XMSConfig = "X-MS-Config"
24+
XMSCache = "X-MS-Cache"
25+
XMSRequestId = "X-MS-Request-Id"
26+
XMSTraceId = "X-MS-Trace-Id"
27+
XMSRetryCount = "X-MS-Retry-count"
28+
)
29+
30+
const (
31+
XMSRetryAttemptCount = "X-MS-Retry-Attempt-count"
32+
XMSCacheStatus = "X-MS-Cache-Status"
33+
)

gateway/internal/errors/errors.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package errors
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/missingstudio/studio/backend/internal/constants"
7+
"github.com/missingstudio/studio/common/errors"
8+
)
9+
10+
var (
11+
ErrProviderHeaderNotExit = errors.NewBadRequest(fmt.Sprintf("%s header is required", constants.XMSProvider))
12+
ErrRequiredHeaderNotExit = errors.NewBadRequest(fmt.Sprintf("either %s or %s header is required", constants.XMSProvider, constants.XMSConfig))
13+
ErrRateLimitExceeded = errors.NewForbidden("rate limit exceeded")
14+
ErrUnauthenticated = errors.NewUnauthorized("unauthenticated")
15+
ErrProviderNotFound = errors.NewNotFound("provider is not found")
16+
)

gateway/internal/interceptor/auth.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@ import (
55
"log/slog"
66

77
"connectrpc.com/connect"
8+
"github.com/missingstudio/studio/backend/internal/constants"
9+
"github.com/missingstudio/studio/backend/internal/errors"
810
)
911

1012
// NewAPIKeyInterceptor returns interceptor which is checking if api key exits
1113
func NewAPIKeyInterceptor(logger *slog.Logger) connect.UnaryInterceptorFunc {
1214
return connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc {
1315
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
14-
apiHeader := req.Header().Get("X-MS-API-KEY")
16+
apiHeader := req.Header().Get(constants.XMSAPIKey)
1517
if apiHeader == "" {
1618
logger.Info("request without api key",
1719
"api_key", apiHeader,
1820
"addr", req.Peer().Addr,
1921
"endpoint", req.Spec().Procedure)
20-
return nil, ErrUnauthenticated
22+
return nil, errors.ErrUnauthenticated
2123
}
2224

2325
return next(ctx, req)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package interceptor
2+
3+
import (
4+
"context"
5+
6+
"connectrpc.com/connect"
7+
"github.com/missingstudio/studio/backend/internal/constants"
8+
"github.com/missingstudio/studio/backend/internal/errors"
9+
"github.com/missingstudio/studio/backend/internal/providers"
10+
)
11+
12+
func ProviderInterceptor() connect.UnaryInterceptorFunc {
13+
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
14+
return connect.UnaryFunc(func(
15+
ctx context.Context,
16+
req connect.AnyRequest,
17+
) (connect.AnyResponse, error) {
18+
// Check if required headers are available
19+
provider := req.Header().Get(constants.XMSProvider)
20+
config := req.Header().Get(constants.XMSConfig)
21+
if provider == "" || config == "" {
22+
return nil, errors.ErrRequiredHeaderNotExit
23+
}
24+
25+
// Check if provider has registered of not
26+
_, ok := providers.ProviderFactories[provider]
27+
if !ok {
28+
return nil, errors.ErrProviderNotFound
29+
}
30+
31+
return next(ctx, req)
32+
})
33+
}
34+
return connect.UnaryInterceptorFunc(interceptor)
35+
}

gateway/internal/interceptor/interceptor.go

Lines changed: 0 additions & 13 deletions
This file was deleted.

gateway/internal/interceptor/provider.go

Lines changed: 0 additions & 24 deletions
This file was deleted.

gateway/internal/interceptor/ratelimit.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55

66
"connectrpc.com/connect"
7+
"github.com/missingstudio/studio/backend/internal/errors"
78
"github.com/missingstudio/studio/backend/internal/ratelimiter"
89
)
910

@@ -15,7 +16,7 @@ func RateLimiterInterceptor(rl *ratelimiter.RateLimiter) connect.UnaryIntercepto
1516
) (connect.AnyResponse, error) {
1617
key := "req_count"
1718
if !rl.Limiter.Validate(key) {
18-
return nil, ErrRateLimitExceeded
19+
return nil, errors.ErrRateLimitExceeded
1920
}
2021

2122
return next(ctx, req)

gateway/internal/providers/providers.go

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,41 @@ package providers
22

33
import (
44
"context"
5-
"fmt"
65
"net/http"
76

7+
"github.com/missingstudio/studio/backend/internal/constants"
8+
"github.com/missingstudio/studio/backend/internal/errors"
89
"github.com/missingstudio/studio/backend/internal/providers/anyscale"
910
"github.com/missingstudio/studio/backend/internal/providers/azure"
1011
"github.com/missingstudio/studio/backend/internal/providers/base"
1112
"github.com/missingstudio/studio/backend/internal/providers/deepinfra"
1213
"github.com/missingstudio/studio/backend/internal/providers/openai"
1314
"github.com/missingstudio/studio/backend/internal/providers/togetherai"
14-
"github.com/missingstudio/studio/common/errors"
15-
)
16-
17-
var (
18-
ErrProviderHeaderNotExit = errors.New(fmt.Errorf("x-ms-provider provider header not available"))
19-
ErrProviderNotFound = errors.NewNotFound("provider is not found")
2015
)
2116

2217
type ProviderFactory interface {
2318
Create(headers http.Header) (base.ProviderInterface, error)
2419
}
2520

26-
var providerFactories = make(map[string]ProviderFactory)
21+
var ProviderFactories = make(map[string]ProviderFactory)
2722

2823
func init() {
29-
providerFactories["openai"] = openai.OpenAIProviderFactory{}
30-
providerFactories["azure"] = azure.AzureProviderFactory{}
31-
providerFactories["anyscale"] = anyscale.AnyscaleProviderFactory{}
32-
providerFactories["deepinfra"] = deepinfra.DeepinfraProviderFactory{}
33-
providerFactories["togetherai"] = togetherai.TogetherAIProviderFactory{}
24+
ProviderFactories["openai"] = openai.OpenAIProviderFactory{}
25+
ProviderFactories["azure"] = azure.AzureProviderFactory{}
26+
ProviderFactories["anyscale"] = anyscale.AnyscaleProviderFactory{}
27+
ProviderFactories["deepinfra"] = deepinfra.DeepinfraProviderFactory{}
28+
ProviderFactories["togetherai"] = togetherai.TogetherAIProviderFactory{}
3429
}
3530

3631
func GetProvider(ctx context.Context, headers http.Header) (base.ProviderInterface, error) {
37-
providerName := headers.Get("x-ms-provider")
32+
providerName := headers.Get(constants.XMSProvider)
3833
if providerName == "" {
39-
return nil, ErrProviderHeaderNotExit
34+
return nil, errors.ErrProviderHeaderNotExit
4035
}
4136

42-
providerFactory, ok := providerFactories[providerName]
37+
providerFactory, ok := ProviderFactories[providerName]
4338
if !ok {
44-
return nil, ErrProviderNotFound
39+
return nil, errors.ErrProviderNotFound
4540
}
4641

4742
return providerFactory.Create(headers)

gateway/main_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"connectrpc.com/connect"
1111

1212
v1 "github.com/missingstudio/studio/backend/internal/api/v1"
13+
"github.com/missingstudio/studio/backend/internal/errors"
1314
"github.com/missingstudio/studio/backend/internal/interceptor"
1415
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
1516
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
@@ -54,7 +55,7 @@ func TestGatewayServer(t *testing.T) {
5455
_, err := client.ChatCompletions(context.Background(), req)
5556

5657
require.NotNil(t, err)
57-
assert.True(t, strings.Contains(err.Error(), interceptor.ErrProviderHeaderNotExit.Error()))
58+
assert.True(t, strings.Contains(err.Error(), errors.ErrRequiredHeaderNotExit.Error()))
5859
}
5960
})
6061
}

gateway/pkg/utils/headers.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"connectrpc.com/connect"
1212
"github.com/go-playground/validator/v10"
13+
"github.com/missingstudio/studio/backend/internal/constants"
1314
"github.com/missingstudio/studio/common/errors"
1415
)
1516

@@ -20,7 +21,7 @@ func isJSON(s string, v interface{}) bool {
2021
}
2122

2223
func UnmarshalConfigHeaders(header http.Header, v interface{}) error {
23-
msconfig := header.Get("x-ms-provider")
24+
msconfig := header.Get(constants.XMSProvider)
2425
if msconfig == "" && isJSON(msconfig, v) {
2526
return ErrGatewayConfigHeaderNotValid
2627
}

0 commit comments

Comments
 (0)