Skip to content

Commit 2d6c5a9

Browse files
committed
feat(gateway): add router for provider failover
1 parent 248cff3 commit 2d6c5a9

6 files changed

Lines changed: 198 additions & 0 deletions

File tree

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package mock
2+
3+
import "github.com/missingstudio/studio/backend/internal/providers/base"
4+
5+
type ProviderMock struct {
6+
Name string
7+
}
8+
9+
func NewProviderMock(name string) base.ProviderInterface {
10+
return &ProviderMock{
11+
Name: name,
12+
}
13+
}
14+
15+
func (p ProviderMock) GetName() string {
16+
return p.Name
17+
}
18+
19+
func (p ProviderMock) Validate() error {
20+
return nil
21+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package router
2+
3+
import (
4+
"sync/atomic"
5+
6+
"github.com/missingstudio/studio/backend/internal/providers/base"
7+
)
8+
9+
const (
10+
Priority Strategy = "priority"
11+
)
12+
13+
type PriorityRouter struct {
14+
idx *atomic.Uint64
15+
providers []base.ProviderInterface
16+
}
17+
18+
func NewPriorityRouter(providers []base.ProviderInterface) *PriorityRouter {
19+
return &PriorityRouter{
20+
idx: &atomic.Uint64{},
21+
providers: providers,
22+
}
23+
}
24+
25+
func (r *PriorityRouter) Iterator() ProviderIterator {
26+
return r
27+
}
28+
29+
func (r *PriorityRouter) Next() (base.ProviderInterface, error) {
30+
idx := int(r.idx.Load())
31+
32+
// Todo: make a check for healthy provider
33+
model := r.providers[idx]
34+
r.idx.Add(1)
35+
36+
return model, nil
37+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package router_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/missingstudio/studio/backend/internal/mock"
7+
"github.com/missingstudio/studio/backend/internal/providers/base"
8+
"github.com/missingstudio/studio/backend/internal/router"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestPriorityRouter(t *testing.T) {
13+
type Provider struct {
14+
Name string
15+
}
16+
17+
type TestCase struct {
18+
providers []Provider
19+
expectedModelIDs []string
20+
}
21+
22+
tests := map[string]TestCase{
23+
"openai": {[]Provider{{"openai"}, {"anyscale"}, {"azure"}}, []string{"openai", "anyscale", "azure"}},
24+
}
25+
26+
for name, tc := range tests {
27+
t.Run(name, func(t *testing.T) {
28+
providers := make([]base.ProviderInterface, 0, len(tc.providers))
29+
30+
for _, provider := range tc.providers {
31+
providers = append(providers, mock.NewProviderMock(provider.Name))
32+
}
33+
34+
routing := router.NewPriorityRouter(providers)
35+
iterator := routing.Iterator()
36+
37+
for _, modelID := range tc.expectedModelIDs {
38+
model, err := iterator.Next()
39+
require.NoError(t, err)
40+
require.Equal(t, modelID, model.GetName())
41+
}
42+
})
43+
}
44+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package router
2+
3+
import (
4+
"sync/atomic"
5+
6+
"github.com/missingstudio/studio/backend/internal/providers/base"
7+
)
8+
9+
const (
10+
RoundRobin Strategy = "roundrobin"
11+
)
12+
13+
type RoundRobinRouter struct {
14+
idx atomic.Uint64
15+
providers []base.ProviderInterface
16+
}
17+
18+
func NewRoundRobinRouter(providers []base.ProviderInterface) *RoundRobinRouter {
19+
return &RoundRobinRouter{
20+
providers: providers,
21+
}
22+
}
23+
24+
func (r *RoundRobinRouter) Iterator() ProviderIterator {
25+
return r
26+
}
27+
28+
func (r *RoundRobinRouter) Next() (base.ProviderInterface, error) {
29+
providerLen := len(r.providers)
30+
31+
// Todo: make a check for healthy provider
32+
idx := r.idx.Add(1) - 1
33+
model := r.providers[idx%uint64(providerLen)]
34+
35+
return model, nil
36+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package router_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/missingstudio/studio/backend/internal/mock"
7+
"github.com/missingstudio/studio/backend/internal/providers/base"
8+
"github.com/missingstudio/studio/backend/internal/router"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestRoundRobinRouter(t *testing.T) {
13+
type Provider struct {
14+
Name string
15+
}
16+
17+
type TestCase struct {
18+
providers []Provider
19+
expectedModelIDs []string
20+
}
21+
22+
tests := map[string]TestCase{
23+
"public llms": {[]Provider{{"openai"}, {"anyscale"}, {"azure"}}, []string{"openai", "anyscale", "azure"}},
24+
}
25+
26+
for name, tc := range tests {
27+
t.Run(name, func(t *testing.T) {
28+
providers := make([]base.ProviderInterface, 0, len(tc.providers))
29+
30+
for _, provider := range tc.providers {
31+
providers = append(providers, mock.NewProviderMock(provider.Name))
32+
}
33+
34+
routing := router.NewRoundRobinRouter(providers)
35+
iterator := routing.Iterator()
36+
37+
// loop three times over the whole pool to check if we return back to the begging of the list
38+
for _, providerName := range tc.expectedModelIDs {
39+
provider, err := iterator.Next()
40+
require.NoError(t, err)
41+
require.Equal(t, providerName, provider.GetName())
42+
}
43+
})
44+
}
45+
}

gateway/internal/router/router.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package router
2+
3+
import (
4+
"errors"
5+
6+
"github.com/missingstudio/studio/backend/internal/providers/base"
7+
)
8+
9+
var ErrNoHealthyProviders = errors.New("no healthy providers found")
10+
11+
type Strategy string
12+
13+
type ProviderIterator interface {
14+
Next() (base.ProviderInterface, error)
15+
}

0 commit comments

Comments
 (0)