Skip to content

Commit 1087447

Browse files
committed
feat(gateway): add additional provider API support
- get provider by id - get provider configs Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
1 parent 84ef294 commit 1087447

17 files changed

Lines changed: 847 additions & 163 deletions

File tree

gateway/internal/api/v1/chatcompletions.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ func (s *V1Handler) GetChatCompletions(
4040
}
4141

4242
providerName := req.Header().Get(constants.XMSProvider)
43-
connectionObj := models.Connection{}
44-
connectionObj.Name = providerName
45-
connectionObj.Headers = headerConfig
46-
43+
connectionObj := models.Connection{
44+
Name: providerName,
45+
Headers: headerConfig,
46+
}
4747
provider, err := s.providerService.GetProvider(connectionObj)
4848
if err != nil {
4949
return nil, errors.New(err)
@@ -74,7 +74,9 @@ func (s *V1Handler) GetChatCompletions(
7474
}
7575

7676
ingesterdata := make(map[string]interface{})
77-
ingesterdata["provider"] = provider.Name()
77+
providerInfo := provider.Info()
78+
79+
ingesterdata["provider"] = providerInfo.Name
7880
ingesterdata["model"] = data.Model
7981
ingesterdata["latency"] = latency
8082
ingesterdata["total_tokens"] = *data.Usage.TotalTokens

gateway/internal/api/v1/models.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.M
1717
continue
1818
}
1919

20-
providerName := provider.Name()
20+
providerInfo := provider.Info()
2121
providerModels := provider.Models()
22+
providerName := providerInfo.Name
2223

2324
var models []*llmv1.Model
2425
for _, val := range providerModels {

gateway/internal/api/v1/providers.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,69 @@ package v1
22

33
import (
44
"context"
5+
"encoding/json"
56

67
"connectrpc.com/connect"
8+
"github.com/missingstudio/studio/backend/models"
9+
"github.com/missingstudio/studio/common/errors"
710
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
811
"google.golang.org/protobuf/types/known/emptypb"
12+
"google.golang.org/protobuf/types/known/structpb"
913
)
1014

1115
func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[emptypb.Empty]) (*connect.Response[llmv1.ProvidersResponse], error) {
1216
providers := s.providerService.GetProviders()
1317

1418
data := []*llmv1.Provider{}
15-
for name := range providers {
19+
for _, provider := range providers {
20+
providerInfo := provider.Info()
1621
data = append(data, &llmv1.Provider{
17-
Name: name,
22+
Title: providerInfo.Title,
23+
Name: providerInfo.Name,
24+
Description: providerInfo.Description,
1825
})
1926
}
2027

2128
return connect.NewResponse(&llmv1.ProvidersResponse{
2229
Providers: data,
2330
}), nil
2431
}
32+
33+
func (s *V1Handler) GetProviderById(ctx context.Context, req *connect.Request[llmv1.GetProviderRequest]) (*connect.Response[llmv1.GetProviderResponse], error) {
34+
provider, err := s.providerService.GetProvider(models.Connection{Name: req.Msg.Id})
35+
if err != nil {
36+
return nil, errors.NewNotFound(err.Error())
37+
}
38+
39+
info := provider.Info()
40+
p := &llmv1.Provider{
41+
Title: info.Title,
42+
Name: info.Name,
43+
Description: info.Description,
44+
}
45+
46+
return connect.NewResponse(&llmv1.GetProviderResponse{
47+
Provider: p,
48+
}), nil
49+
}
50+
51+
func (s *V1Handler) GetProviderConfig(ctx context.Context, req *connect.Request[llmv1.GetProviderConfigRequest]) (*connect.Response[llmv1.GetProviderConfigResponse], error) {
52+
provider, err := s.providerService.GetProvider(models.Connection{Name: req.Msg.Id})
53+
if err != nil {
54+
return nil, errors.NewNotFound(err.Error())
55+
}
56+
57+
configs := map[string]any{}
58+
if err := json.Unmarshal(provider.Schema(), &configs); err != nil {
59+
return nil, errors.NewInternalError(err.Error())
60+
}
61+
62+
stConfigs, err := structpb.NewStruct(configs)
63+
if err != nil {
64+
return nil, errors.NewInternalError(err.Error())
65+
}
66+
67+
return connect.NewResponse(&llmv1.GetProviderConfigResponse{
68+
Config: stConfigs,
69+
}), nil
70+
}

gateway/internal/mock/mock_provider.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,22 @@ import "github.com/missingstudio/studio/backend/internal/providers/base"
55
var _ base.IProvider = &providerMock{}
66

77
type providerMock struct {
8-
name string
8+
info base.ProviderInfo
9+
config base.ProviderConfig
910
}
1011

1112
func NewProviderMock(name string) base.IProvider {
1213
return &providerMock{
13-
name: name,
14+
info: base.ProviderInfo{Name: name},
1415
}
1516
}
1617

17-
func (p providerMock) Name() string {
18-
return p.name
18+
func (p providerMock) Info() base.ProviderInfo {
19+
return p.info
20+
}
21+
22+
func (p providerMock) Config() base.ProviderConfig {
23+
return p.config
1924
}
2025

2126
func (p providerMock) Schema() []byte {

gateway/internal/providers/anyscale/base.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,32 @@ var schema []byte
1313
var _ base.IProvider = &anyscaleProvider{}
1414

1515
type anyscaleProvider struct {
16-
name string
16+
info base.ProviderInfo
1717
config base.ProviderConfig
1818
conn models.Connection
1919
}
2020

21-
func (anyscale anyscaleProvider) Name() string {
22-
return anyscale.name
21+
func (anyscale anyscaleProvider) Info() base.ProviderInfo {
22+
return anyscale.info
23+
}
24+
25+
func (anyscale anyscaleProvider) Config() base.ProviderConfig {
26+
return anyscale.config
2327
}
2428

2529
func (anyscale anyscaleProvider) Schema() []byte {
2630
return schema
2731
}
2832

33+
func getAnyscaleInfo() base.ProviderInfo {
34+
return base.ProviderInfo{
35+
Title: "Anyscale",
36+
Name: "anyscale",
37+
Description: `Anyscale Endpoints is a fast and scalable API to integrate OSS LLMs into your app.
38+
Use our growing list of high performance models or deploy your own.`,
39+
}
40+
}
41+
2942
func getAnyscaleConfig(baseURL string) base.ProviderConfig {
3043
return base.ProviderConfig{
3144
BaseURL: baseURL,
@@ -37,7 +50,7 @@ func init() {
3750
models.ProviderRegistry["anyscale"] = func(connection models.Connection) base.IProvider {
3851
config := getAnyscaleConfig("https://api.endpoints.anyscale.com")
3952
return &anyscaleProvider{
40-
name: "Anyscale",
53+
info: getAnyscaleInfo(),
4154
config: config,
4255
conn: connection,
4356
}

gateway/internal/providers/azure/base.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,36 @@ var schema []byte
1313
var _ base.IProvider = &azureProvider{}
1414

1515
type azureProvider struct {
16-
name string
16+
info base.ProviderInfo
1717
config base.ProviderConfig
1818
conn models.Connection
1919
}
2020

21-
func (az azureProvider) Name() string {
22-
return az.name
21+
func (anyscale azureProvider) Info() base.ProviderInfo {
22+
return anyscale.info
23+
}
24+
25+
func (az azureProvider) Config() base.ProviderConfig {
26+
return az.config
2327
}
2428

2529
func (az azureProvider) Schema() []byte {
2630
return schema
2731
}
2832

33+
func getAzureInfo() base.ProviderInfo {
34+
return base.ProviderInfo{
35+
Title: "Azure",
36+
Name: "azure",
37+
Description: "Azure OpenAI Service offers industry-leading coding and language AI models that you can fine-tune to your specific needs for a variety of use cases.",
38+
}
39+
}
40+
2941
func getAzureConfig() base.ProviderConfig {
3042
return base.ProviderConfig{
3143
BaseURL: "",
3244
ChatCompletions: "/chat/completions",
3345
}
3446
}
3547

36-
func init() {
37-
models.ProviderRegistry["azure"] = func(connection models.Connection) base.IProvider {
38-
config := getAzureConfig()
39-
return &azureProvider{
40-
name: "Azure",
41-
config: config,
42-
conn: connection,
43-
}
44-
}
45-
}
48+
func init() {}

gateway/internal/providers/base/base.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@ type ProviderConfig struct {
99
BaseURL string
1010
ChatCompletions string
1111
}
12+
type ProviderInfo struct {
13+
Title string
14+
Name string
15+
Description string
16+
}
1217

1318
type IProvider interface {
14-
Name() string
19+
Info() ProviderInfo
20+
Config() ProviderConfig
1521
Models() []string
1622
Schema() []byte
1723
}

gateway/internal/providers/deepinfra/base.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,32 @@ var schema []byte
1313
var _ base.IProvider = &deepinfraProvider{}
1414

1515
type deepinfraProvider struct {
16-
name string
16+
info base.ProviderInfo
1717
config base.ProviderConfig
1818
conn models.Connection
1919
}
2020

21-
func (deepinfra deepinfraProvider) Name() string {
22-
return deepinfra.name
21+
func (anyscale deepinfraProvider) Info() base.ProviderInfo {
22+
return anyscale.info
23+
}
24+
25+
func (deepinfra deepinfraProvider) Config() base.ProviderConfig {
26+
return deepinfra.config
2327
}
2428

2529
func (deepinfra deepinfraProvider) Schema() []byte {
2630
return schema
2731
}
2832

33+
func getDeepinfraInfo() base.ProviderInfo {
34+
return base.ProviderInfo{
35+
Title: "Deepinfra",
36+
Name: "deepinfra",
37+
Description: `Deep Infra offers 100+ machine learning models from Text-to-Image, Object-Detection,
38+
Automatic-Speech-Recognition, Text-to-Text Generation, and more!`,
39+
}
40+
}
41+
2942
func getDeepinfraConfig(baseURL string) base.ProviderConfig {
3043
return base.ProviderConfig{
3144
BaseURL: baseURL,
@@ -37,7 +50,7 @@ func init() {
3750
models.ProviderRegistry["deepinfra"] = func(connection models.Connection) base.IProvider {
3851
config := getDeepinfraConfig("https://api.deepinfra.com/v1/openai")
3952
return &deepinfraProvider{
40-
name: "Deepinfra",
53+
info: getDeepinfraInfo(),
4154
config: config,
4255
conn: connection,
4356
}

gateway/internal/providers/openai/base.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,31 @@ var schema []byte
1313
var _ base.IProvider = &openAIProvider{}
1414

1515
type openAIProvider struct {
16-
name string
16+
info base.ProviderInfo
1717
config base.ProviderConfig
1818
conn models.Connection
1919
}
2020

21-
func (oai openAIProvider) Name() string {
22-
return oai.name
21+
func (anyscale openAIProvider) Info() base.ProviderInfo {
22+
return anyscale.info
23+
}
24+
25+
func (oai openAIProvider) Config() base.ProviderConfig {
26+
return oai.config
2327
}
2428

2529
func (oai openAIProvider) Schema() []byte {
2630
return schema
2731
}
2832

33+
func getOpenAIInfo() base.ProviderInfo {
34+
return base.ProviderInfo{
35+
Title: "OpenAI",
36+
Name: "openai",
37+
Description: `OpenAI API platform offers latest models and guides for safety best practices.`,
38+
}
39+
}
40+
2941
func getOpenAIConfig(baseURL string) base.ProviderConfig {
3042
return base.ProviderConfig{
3143
BaseURL: baseURL,
@@ -37,7 +49,7 @@ func init() {
3749
models.ProviderRegistry["openai"] = func(connection models.Connection) base.IProvider {
3850
config := getOpenAIConfig("https://api.openai.com")
3951
return &openAIProvider{
40-
name: "OpenAI",
52+
info: getOpenAIInfo(),
4153
config: config,
4254
conn: connection,
4355
}

gateway/internal/providers/togetherai/base.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,31 @@ var schema []byte
1313
var _ base.IProvider = &togetherAIProvider{}
1414

1515
type togetherAIProvider struct {
16-
name string
16+
info base.ProviderInfo
1717
config base.ProviderConfig
1818
conn models.Connection
1919
}
2020

21-
func (togetherAI togetherAIProvider) Name() string {
22-
return togetherAI.name
21+
func (anyscale togetherAIProvider) Info() base.ProviderInfo {
22+
return anyscale.info
23+
}
24+
25+
func (togetherAI togetherAIProvider) Config() base.ProviderConfig {
26+
return togetherAI.config
2327
}
2428

2529
func (togetherAI togetherAIProvider) Schema() []byte {
2630
return schema
2731
}
2832

33+
func getTogetherAIInfo() base.ProviderInfo {
34+
return base.ProviderInfo{
35+
Title: "Together AI",
36+
Name: "togetherai",
37+
Description: `Build gen AI models with Together AI. Benefit from the fastest and most cost-efficient tools and infra.`,
38+
}
39+
}
40+
2941
func getTogetherAIConfig(baseURL string) base.ProviderConfig {
3042
return base.ProviderConfig{
3143
BaseURL: baseURL,
@@ -37,7 +49,7 @@ func init() {
3749
models.ProviderRegistry["togetherai"] = func(connection models.Connection) base.IProvider {
3850
config := getTogetherAIConfig("https://api.together.xyz")
3951
return &togetherAIProvider{
40-
name: "Together AI",
52+
info: getTogetherAIInfo(),
4153
config: config,
4254
conn: connection,
4355
}

0 commit comments

Comments
 (0)