Skip to content

Commit

Permalink
#44 Covered RR routing by tests & init routing based on config value …
Browse files Browse the repository at this point in the history
…& created interfaces for Model & LanguageModels
  • Loading branch information
roma-glushko committed Jan 14, 2024
1 parent 5136e35 commit 00499bf
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 26 deletions.
10 changes: 10 additions & 0 deletions pkg/providers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ type LangModelProvider interface {
Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error)
}

// Model is the identity-and-health view of a model.
// Routing strategies keep their pools as []Model, and round-robin routing
// checks Healthy while looking for the next model to return.
type Model interface {
// ID returns the identifier of the model.
ID() string
// Healthy reports whether the model should currently be considered for routing.
Healthy() bool
}

// LanguageModel combines Model (ID and Healthy) with LangModelProvider (Chat),
// i.e. it describes a model that can both be picked by a routing strategy
// and serve chat requests.
type LanguageModel interface {
Model
LangModelProvider
}

// LangModel
type LangModel struct {
modelID string
Expand Down
20 changes: 20 additions & 0 deletions pkg/providers/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,23 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.UnifiedChatRequest) (*
// Provider returns the constant provider name reported by the mock.
// The receiver is not used, so it is left unnamed.
func (*ProviderMock) Provider() string {
	return "provider_mock"
}

// LangModelMock is a test double that reports a fixed ID and a fixed
// health status, both chosen at construction time.
type LangModelMock struct {
	modelID string // value returned by ID
	healthy bool   // value returned by Healthy
}

// NewLangModelMock creates a mock that always reports the given ID and
// health status.
func NewLangModelMock(ID string, healthy bool) *LangModelMock {
	mock := new(LangModelMock)
	mock.modelID = ID
	mock.healthy = healthy

	return mock
}

// ID returns the identifier the mock was created with.
func (m *LangModelMock) ID() string {
	return m.modelID
}

// Healthy returns the health status the mock was created with.
func (m *LangModelMock) Healthy() bool {
	return m.healthy
}
22 changes: 20 additions & 2 deletions pkg/routers/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package routers

import (
"fmt"

"glide/pkg/providers"
"glide/pkg/routers/retry"
"glide/pkg/routers/routing"
Expand Down Expand Up @@ -53,10 +55,10 @@ type LangRouterConfig struct {
}

// BuildModels creates LanguageModel slice out of the given config
func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LangModel, error) {
func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.LanguageModel, error) {
var errs error

models := make([]*providers.LangModel, 0, len(c.Models))
models := make([]providers.LanguageModel, 0, len(c.Models))

for _, modelConfig := range c.Models {
if !modelConfig.Enabled {
Expand Down Expand Up @@ -102,6 +104,22 @@ func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry {
)
}

// BuildRouting creates the routing strategy selected by c.RoutingStrategy
// over the given models.
//
// It returns an error when the configured strategy is neither
// routing.Priority nor routing.RoundRobin.
func (c *LangRouterConfig) BuildRouting(models []providers.LanguageModel) (routing.LangModelRouting, error) {
	// A []providers.LanguageModel is not assignable to []providers.Model,
	// so the elements are copied one by one into a slice of the narrower interface.
	pool := make([]providers.Model, len(models))
	for i := range models {
		pool[i] = models[i]
	}

	switch strategy := c.RoutingStrategy; strategy {
	case routing.Priority:
		return routing.NewPriorityRouting(pool), nil
	case routing.RoundRobin:
		return routing.NewRoundRobinRouting(pool), nil
	default:
		return nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", strategy)
	}
}

func DefaultLangRouterConfig() LangRouterConfig {
return LangRouterConfig{
Enabled: true,
Expand Down
17 changes: 12 additions & 5 deletions pkg/routers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type LangRouter struct {
Config *LangRouterConfig
routing routing.LangModelRouting
retry *retry.ExpRetry
models []*providers.LangModel
models []providers.LanguageModel
telemetry *telemetry.Telemetry
}

Expand All @@ -34,12 +34,17 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter
return nil, err
}

routing, err := cfg.BuildRouting(models)
if err != nil {
return nil, err
}

router := &LangRouter{
routerID: cfg.ID,
Config: cfg,
models: models,
retry: cfg.BuildRetry(),
routing: routing.NewPriorityRouting(models),
routing: routing,
telemetry: tel,
}

Expand Down Expand Up @@ -68,13 +73,15 @@ func (r *LangRouter) Chat(ctx context.Context, request *schemas.UnifiedChatReque
break
}

resp, err := model.Chat(ctx, request)
langModel := model.(providers.LanguageModel)

resp, err := langModel.Chat(ctx, request)
if err != nil {
r.telemetry.Logger.Warn(
"lang model failed processing chat request",
zap.String("routerID", r.ID()),
zap.String("modelID", model.ID()),
zap.String("provider", model.Provider()),
zap.String("modelID", langModel.ID()),
zap.String("provider", langModel.Provider()),
zap.Error(err),
)

Expand Down
45 changes: 35 additions & 10 deletions pkg/routers/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
budget := health.NewErrorBudget(3, health.SEC)
models := []*providers.LangModel{
langModels := []providers.LanguageModel{
providers.NewLangModel(
"first",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
Expand All @@ -31,12 +31,17 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
}

router := LangRouter{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
routing: routing.NewPriorityRouting(models),
models: models,
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}

Expand All @@ -54,7 +59,7 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {

func TestLangRouter_Priority_PickSecondHealthy(t *testing.T) {
budget := health.NewErrorBudget(3, health.SEC)
models := []*providers.LangModel{
langModels := []providers.LanguageModel{
providers.NewLangModel(
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
Expand All @@ -67,14 +72,19 @@ func TestLangRouter_Priority_PickSecondHealthy(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
}

expectedModels := []string{"second", "first"}

router := LangRouter{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
routing: routing.NewPriorityRouting(models),
models: models,
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}

Expand All @@ -92,7 +102,7 @@ func TestLangRouter_Priority_PickSecondHealthy(t *testing.T) {

func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
budget := health.NewErrorBudget(3, health.SEC)
models := []*providers.LangModel{
langModels := []providers.LanguageModel{
providers.NewLangModel(
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
Expand All @@ -105,12 +115,17 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
}

router := LangRouter{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
models: models,
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}

Expand All @@ -123,7 +138,7 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {

func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
budget := health.NewErrorBudget(1, health.MIN)
models := []*providers.LangModel{
langModels := []providers.LanguageModel{
providers.NewLangModel(
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}),
Expand All @@ -136,12 +151,17 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
}

router := LangRouter{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
models: models,
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}

Expand All @@ -156,7 +176,7 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {

func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
budget := health.NewErrorBudget(3, health.SEC)
models := []*providers.LangModel{
langModels := []providers.LanguageModel{
providers.NewLangModel(
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
Expand All @@ -169,12 +189,17 @@ func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
}

router := LangRouter{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
models: models,
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/routers/routing/priority.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ const (
// Priority of models is defined by the position of the model in the list
// (e.g. the first model definition has the highest priority, then the second model definition, and so on)
type PriorityRouting struct {
models []*providers.LangModel
models []providers.Model
}

func NewPriorityRouting(models []*providers.LangModel) *PriorityRouting {
func NewPriorityRouting(models []providers.Model) *PriorityRouting {
return &PriorityRouting{
models: models,
}
Expand All @@ -35,10 +35,10 @@ func (r *PriorityRouting) Iterator() LangModelIterator {

type PriorityIterator struct {
idx *atomic.Uint64
models []*providers.LangModel
models []providers.Model
}

func (r PriorityIterator) Next() (*providers.LangModel, error) {
func (r PriorityIterator) Next() (providers.Model, error) {
models := r.models
idx := r.idx.Load()

Expand Down
8 changes: 4 additions & 4 deletions pkg/routers/routing/round_robin.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ const (
// RoundRobinRouting routes request to the next model in the list in cycle
type RoundRobinRouting struct {
idx atomic.Uint64
models []*providers.LangModel
models []providers.Model
}

func NewRoundRobinRouting(models []*providers.LangModel) *RoundRobinRouting {
func NewRoundRobinRouting(models []providers.Model) *RoundRobinRouting {
return &RoundRobinRouting{
models: models,
}
Expand All @@ -26,13 +26,13 @@ func (r *RoundRobinRouting) Iterator() LangModelIterator {
return r
}

func (r *RoundRobinRouting) Next() (*providers.LangModel, error) {
func (r *RoundRobinRouting) Next() (providers.Model, error) {
modelLen := len(r.models)

// to avoid an infinite loop when no healthy model is available,
// we track whether we have made a whole cycle around the model slice looking for a healthy model
for i := 0; i < modelLen; i++ {
idx := r.idx.Add(1)
idx := r.idx.Add(1) - 1
model := r.models[idx%uint64(modelLen)]

if !model.Healthy() {
Expand Down
63 changes: 63 additions & 0 deletions pkg/routers/routing/round_robin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package routing

import (
"testing"

"github.com/stretchr/testify/require"
"glide/pkg/providers"
)

// TestRoundRobinRouting_PickModelsSequentially checks that the round-robin
// iterator hands out healthy models in list order, skips unhealthy ones, and
// wraps around to the start of the pool.
func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {
	type modelSpec struct {
		modelID string
		healthy bool
	}

	type scenario struct {
		models           []modelSpec
		expectedModelIDs []string
	}

	scenarios := map[string]scenario{
		"all healthy": {
			models:           []modelSpec{{"first", true}, {"second", true}, {"third", true}},
			expectedModelIDs: []string{"first", "second", "third"},
		},
		"unhealthy in the middle": {
			models:           []modelSpec{{"first", true}, {"second", false}, {"third", true}},
			expectedModelIDs: []string{"first", "third"},
		},
		"two unhealthy": {
			models:           []modelSpec{{"first", true}, {"second", false}, {"third", false}},
			expectedModelIDs: []string{"first"},
		},
		"first unhealthy": {
			models:           []modelSpec{{"first", false}, {"second", true}, {"third", true}},
			expectedModelIDs: []string{"second", "third"},
		},
	}

	for name, sc := range scenarios {
		t.Run(name, func(t *testing.T) {
			pool := make([]providers.Model, 0, len(sc.models))

			for _, spec := range sc.models {
				pool = append(pool, providers.NewLangModelMock(spec.modelID, spec.healthy))
			}

			iterator := NewRoundRobinRouting(pool).Iterator()

			// walk the expected sequence three times in a row to check that the
			// iterator returns to the beginning of the list after each full cycle
			for cycle := 0; cycle < 3; cycle++ {
				for _, wantID := range sc.expectedModelIDs {
					model, err := iterator.Next()
					require.NoError(t, err)
					require.Equal(t, wantID, model.ID())
				}
			}
		})
	}
}

// TestRoundRobinRouting_NoHealthyModels checks that Next returns an error
// when every model in the pool is unhealthy.
func TestRoundRobinRouting_NoHealthyModels(t *testing.T) {
	modelIDs := []string{"first", "second", "third"}

	pool := make([]providers.Model, 0, len(modelIDs))
	for _, id := range modelIDs {
		pool = append(pool, providers.NewLangModelMock(id, false))
	}

	_, err := NewRoundRobinRouting(pool).Iterator().Next()
	require.Error(t, err)
}
Loading

0 comments on commit 00499bf

Please sign in to comment.