@@ -2,46 +2,41 @@ package providers
22
33import (
44 "context"
5- "fmt"
65 "net/http"
76
7+ "github.com/missingstudio/studio/backend/internal/constants"
8+ "github.com/missingstudio/studio/backend/internal/errors"
89 "github.com/missingstudio/studio/backend/internal/providers/anyscale"
910 "github.com/missingstudio/studio/backend/internal/providers/azure"
1011 "github.com/missingstudio/studio/backend/internal/providers/base"
1112 "github.com/missingstudio/studio/backend/internal/providers/deepinfra"
1213 "github.com/missingstudio/studio/backend/internal/providers/openai"
1314 "github.com/missingstudio/studio/backend/internal/providers/togetherai"
14- "github.com/missingstudio/studio/common/errors"
15- )
16-
17- var (
18- ErrProviderHeaderNotExit = errors .New (fmt .Errorf ("x-ms-provider provider header not available" ))
19- ErrProviderNotFound = errors .NewNotFound ("provider is not found" )
2015)
2116
2217type ProviderFactory interface {
2318 Create (headers http.Header ) (base.ProviderInterface , error )
2419}
2520
26- var providerFactories = make (map [string ]ProviderFactory )
21+ var ProviderFactories = make (map [string ]ProviderFactory )
2722
2823func init () {
29- providerFactories ["openai" ] = openai.OpenAIProviderFactory {}
30- providerFactories ["azure" ] = azure.AzureProviderFactory {}
31- providerFactories ["anyscale" ] = anyscale.AnyscaleProviderFactory {}
32- providerFactories ["deepinfra" ] = deepinfra.DeepinfraProviderFactory {}
33- providerFactories ["togetherai" ] = togetherai.TogetherAIProviderFactory {}
24+ ProviderFactories ["openai" ] = openai.OpenAIProviderFactory {}
25+ ProviderFactories ["azure" ] = azure.AzureProviderFactory {}
26+ ProviderFactories ["anyscale" ] = anyscale.AnyscaleProviderFactory {}
27+ ProviderFactories ["deepinfra" ] = deepinfra.DeepinfraProviderFactory {}
28+ ProviderFactories ["togetherai" ] = togetherai.TogetherAIProviderFactory {}
3429}
3530
3631func GetProvider (ctx context.Context , headers http.Header ) (base.ProviderInterface , error ) {
37- providerName := headers .Get ("x-ms-provider" )
32+ providerName := headers .Get (constants . XMSProvider )
3833 if providerName == "" {
39- return nil , ErrProviderHeaderNotExit
34+ return nil , errors . ErrProviderHeaderNotExit
4035 }
4136
42- providerFactory , ok := providerFactories [providerName ]
37+ providerFactory , ok := ProviderFactories [providerName ]
4338 if ! ok {
44- return nil , ErrProviderNotFound
39+ return nil , errors . ErrProviderNotFound
4540 }
4641
4742 return providerFactory .Create (headers )
0 commit comments