Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extproc: remove the path from the translator factory #334

Merged
merged 5 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 65 additions & 55 deletions internal/extproc/chatcompletion_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,42 +38,52 @@ type chatCompletionProcessor struct {
costs translator.LLMTokenUsage
}

// selectTranslator selects the translator based on the output schema.
func (c *chatCompletionProcessor) selectTranslator(out filterapi.VersionedAPISchema) error {
if c.translator != nil { // Prevents re-selection and allows translator injection in tests.
return nil
}
// TODO: currently, we ignore the LLMAPISchema."Version" field.
switch out.Name {
case filterapi.APISchemaOpenAI:
c.translator = translator.NewChatCompletionOpenAIToOpenAITranslator()
case filterapi.APISchemaAWSBedrock:
c.translator = translator.NewChatCompletionOpenAIToAWSBedrockTranslator()
default:
return fmt.Errorf("unsupported API schema: backend=%s", out)
}
return nil
}

// ProcessRequestHeaders implements [ProcessorIface.ProcessRequestHeaders].
func (p *chatCompletionProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
func (c *chatCompletionProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
// The request headers have already been at the time the processor was created
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{},
}}, nil
}

// ProcessRequestBody implements [ProcessorIface.ProcessRequestBody].
func (p *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
path := p.requestHeaders[":path"]
model, body, err := p.config.bodyParser(path, rawBody)
func (c *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
path := c.requestHeaders[":path"]
model, body, err := c.config.bodyParser(path, rawBody)
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
}
p.logger.Info("Processing request", "path", path, "model", model)
c.logger.Info("Processing request", "path", path, "model", model)

p.requestHeaders[p.config.modelNameHeaderKey] = model
b, err := p.config.router.Calculate(p.requestHeaders)
c.requestHeaders[c.config.modelNameHeaderKey] = model
b, err := c.config.router.Calculate(c.requestHeaders)
if err != nil {
return nil, fmt.Errorf("failed to calculate route: %w", err)
}
p.logger.Info("Selected backend", "backend", b.Name)
c.logger.Info("Selected backend", "backend", b.Name)

factory, ok := p.config.factories[b.Schema]
if !ok {
return nil, fmt.Errorf("failed to find factory for output schema %q", b.Schema)
}

t, err := factory(path)
if err != nil {
return nil, fmt.Errorf("failed to create translator: %w", err)
if err = c.selectTranslator(b.Schema); err != nil {
mathetake marked this conversation as resolved.
Show resolved Hide resolved
return nil, fmt.Errorf("failed to select translator: %w", err)
}
p.translator = t

headerMutation, bodyMutation, override, err := p.translator.RequestBody(body)
headerMutation, bodyMutation, override, err := c.translator.RequestBody(body)
if err != nil {
return nil, fmt.Errorf("failed to transform request: %w", err)
}
Expand All @@ -83,13 +93,13 @@ func (p *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBod
}
// Set the model name to the request header with the key `x-ai-gateway-llm-model-name`.
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.modelNameHeaderKey, RawValue: []byte(model)},
Header: &corev3.HeaderValue{Key: c.config.modelNameHeaderKey, RawValue: []byte(model)},
}, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.selectedBackendHeaderKey, RawValue: []byte(b.Name)},
Header: &corev3.HeaderValue{Key: c.config.selectedBackendHeaderKey, RawValue: []byte(b.Name)},
})

if authHandler, ok := p.config.backendAuthHandlers[b.Name]; ok {
if err := authHandler.Do(ctx, p.requestHeaders, headerMutation, bodyMutation); err != nil {
if authHandler, ok := c.config.backendAuthHandlers[b.Name]; ok {
if err := authHandler.Do(ctx, c.requestHeaders, headerMutation, bodyMutation); err != nil {
return nil, fmt.Errorf("failed to do auth request: %w", err)
}
}
Expand All @@ -110,19 +120,19 @@ func (p *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBod
}

// ProcessResponseHeaders implements [ProcessorIface.ProcessResponseHeaders].
func (p *chatCompletionProcessor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
p.responseHeaders = headersToMap(headers)
if enc := p.responseHeaders["content-encoding"]; enc != "" {
p.responseEncoding = enc
func (c *chatCompletionProcessor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
c.responseHeaders = headersToMap(headers)
if enc := c.responseHeaders["content-encoding"]; enc != "" {
c.responseEncoding = enc
}
// The translator can be nil as there could be response event generated by previous ext proc without
// getting the request event.
if p.translator == nil {
if c.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extprocv3.HeadersResponse{},
}}, nil
}
headerMutation, err := p.translator.ResponseHeaders(p.responseHeaders)
headerMutation, err := c.translator.ResponseHeaders(c.responseHeaders)
if err != nil {
return nil, fmt.Errorf("failed to transform response headers: %w", err)
}
Expand All @@ -134,9 +144,9 @@ func (p *chatCompletionProcessor) ProcessResponseHeaders(_ context.Context, head
}

// ProcessResponseBody implements [ProcessorIface.ProcessResponseBody].
func (p *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
func (c *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
var br io.Reader
switch p.responseEncoding {
switch c.responseEncoding {
case "gzip":
br, err = gzip.NewReader(bytes.NewReader(body.Body))
if err != nil {
Expand All @@ -147,11 +157,11 @@ func (p *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *e
}
// The translator can be nil as there could be response event generated by previous ext proc without
// getting the request event.
if p.translator == nil {
if c.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseBody{}}, nil
}

headerMutation, bodyMutation, tokenUsage, err := p.translator.ResponseBody(p.responseHeaders, br, body.EndOfStream)
headerMutation, bodyMutation, tokenUsage, err := c.translator.ResponseBody(c.responseHeaders, br, body.EndOfStream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}
Expand All @@ -168,55 +178,55 @@ func (p *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *e
}

// TODO: this is coupled with "LLM" specific logic. Once we have another use case, we need to refactor this.
p.costs.InputTokens += tokenUsage.InputTokens
p.costs.OutputTokens += tokenUsage.OutputTokens
p.costs.TotalTokens += tokenUsage.TotalTokens
if body.EndOfStream && len(p.config.requestCosts) > 0 {
resp.DynamicMetadata, err = p.maybeBuildDynamicMetadata()
c.costs.InputTokens += tokenUsage.InputTokens
c.costs.OutputTokens += tokenUsage.OutputTokens
c.costs.TotalTokens += tokenUsage.TotalTokens
if body.EndOfStream && len(c.config.requestCosts) > 0 {
resp.DynamicMetadata, err = c.maybeBuildDynamicMetadata()
if err != nil {
return nil, fmt.Errorf("failed to build dynamic metadata: %w", err)
}
}
return resp, nil
}

func (p *chatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
metadata := make(map[string]*structpb.Value, len(p.config.requestCosts))
for i := range p.config.requestCosts {
c := &p.config.requestCosts[i]
func (c *chatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
metadata := make(map[string]*structpb.Value, len(c.config.requestCosts))
for i := range c.config.requestCosts {
rc := &c.config.requestCosts[i]
var cost uint32
switch c.Type {
switch rc.Type {
case filterapi.LLMRequestCostTypeInputToken:
cost = p.costs.InputTokens
cost = c.costs.InputTokens
case filterapi.LLMRequestCostTypeOutputToken:
cost = p.costs.OutputTokens
cost = c.costs.OutputTokens
case filterapi.LLMRequestCostTypeTotalToken:
cost = p.costs.TotalTokens
cost = c.costs.TotalTokens
case filterapi.LLMRequestCostTypeCELExpression:
costU64, err := llmcostcel.EvaluateProgram(
c.celProg,
p.requestHeaders[p.config.modelNameHeaderKey],
p.requestHeaders[p.config.selectedBackendHeaderKey],
p.costs.InputTokens,
p.costs.OutputTokens,
p.costs.TotalTokens,
rc.celProg,
c.requestHeaders[c.config.modelNameHeaderKey],
c.requestHeaders[c.config.selectedBackendHeaderKey],
c.costs.InputTokens,
c.costs.OutputTokens,
c.costs.TotalTokens,
)
if err != nil {
return nil, fmt.Errorf("failed to evaluate CEL expression: %w", err)
}
cost = uint32(costU64) //nolint:gosec
default:
return nil, fmt.Errorf("unknown request cost kind: %s", c.Type)
return nil, fmt.Errorf("unknown request cost kind: %s", rc.Type)
}
p.logger.Info("Setting request cost metadata", "type", c.Type, "cost", cost, "metadataKey", c.MetadataKey)
metadata[c.MetadataKey] = &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(cost)}}
c.logger.Info("Setting request cost metadata", "type", rc.Type, "cost", cost, "metadataKey", rc.MetadataKey)
metadata[rc.MetadataKey] = &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(cost)}}
}
if len(metadata) == 0 {
return nil, nil
}
return &structpb.Struct{
Fields: map[string]*structpb.Value{
p.config.metadataNamespace: {
c.config.metadataNamespace: {
Kind: &structpb.Value_StructValue{
StructValue: &structpb.Struct{Fields: metadata},
},
Expand Down
51 changes: 22 additions & 29 deletions internal/extproc/chatcompletion_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ import (
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

func TestChatCompletion_SelectTranslator(t *testing.T) {
c := &chatCompletionProcessor{}
t.Run("unsupported", func(t *testing.T) {
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: "Bar", Version: "v123"})
require.ErrorContains(t, err, "unsupported API schema: backend={Bar v123}")
})
t.Run("supported openai", func(t *testing.T) {
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI})
require.NoError(t, err)
require.NotNil(t, c.translator)
})
t.Run("supported aws bedrock", func(t *testing.T) {
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock})
require.NoError(t, err)
require.NotNil(t, c.translator)
})
}

func TestChatCompletion_ProcessRequestHeaders(t *testing.T) {
p := &chatCompletionProcessor{}
res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{
Expand Down Expand Up @@ -128,27 +146,9 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: make(map[filterapi.VersionedAPISchema]translator.Factory),
}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to find factory for output schema {\"some-schema\" \"v10.0\"}")
})
t.Run("translator factory error", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
rbp := mockRequestBodyParser{t: t, retModelName: "some-model", expPath: "/foo"}
rt := mockRouter{
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
factory := mockTranslatorFactory{t: t, retErr: errors.New("test error"), expPath: "/foo"}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
},
}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to create translator: test error")
require.ErrorContains(t, err, "unsupported API schema: backend={some-schema v10.0}")
})
t.Run("translator error", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
Expand All @@ -157,13 +157,10 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
factory := mockTranslatorFactory{t: t, retTranslator: mockTranslator{t: t, retErr: errors.New("test error")}, expPath: "/foo"}
tr := mockTranslator{t: t, retErr: errors.New("test error")}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
},
}, requestHeaders: headers, logger: slog.Default()}
}, requestHeaders: headers, logger: slog.Default(), translator: tr}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to transform request: test error")
})
Expand All @@ -178,15 +175,11 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
headerMut := &extprocv3.HeaderMutation{}
bodyMut := &extprocv3.BodyMutation{}
mt := mockTranslator{t: t, expRequestBody: someBody, retHeaderMutation: headerMut, retBodyMutation: bodyMut}
factory := mockTranslatorFactory{t: t, retTranslator: mt, expPath: "/foo"}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
},
selectedBackendHeaderKey: "x-ai-gateway-backend-key",
modelNameHeaderKey: "x-ai-gateway-model-key",
}, requestHeaders: headers, logger: slog.Default()}
}, requestHeaders: headers, logger: slog.Default(), translator: mt}
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.NoError(t, err)
require.Equal(t, mt, p.translator)
Expand Down
14 changes: 0 additions & 14 deletions internal/extproc/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,6 @@ func (m *mockRequestBodyParser) impl(path string, body *extprocv3.HttpBody) (mod
return m.retModelName, m.retRb, m.retErr
}

// mockTranslatorFactory implements [translator.Factory] for testing.
type mockTranslatorFactory struct {
t *testing.T
expPath string
retTranslator translator.Translator
retErr error
}

// NewTranslator implements [translator.Factory].
func (m mockTranslatorFactory) impl(path string) (translator.Translator, error) {
require.Equal(m.t, m.expPath, path)
return m.retTranslator, m.retErr
}

// mockExternalProcessingStream implements [extprocv3.ExternalProcessor_ProcessServer] for testing.
type mockExternalProcessingStream struct {
t *testing.T
Expand Down
2 changes: 0 additions & 2 deletions internal/extproc/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/envoyproxy/ai-gateway/filterapi/x"
"github.com/envoyproxy/ai-gateway/internal/extproc/backendauth"
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
)

// processorConfig is the configuration for the processor.
Expand All @@ -22,7 +21,6 @@ type processorConfig struct {
bodyParser router.RequestBodyParser
router x.Router
modelNameHeaderKey, selectedBackendHeaderKey string
factories map[filterapi.VersionedAPISchema]translator.Factory
backendAuthHandlers map[string]backendauth.Handler
metadataNamespace string
requestCosts []processorConfigRequestCost
Expand Down
10 changes: 0 additions & 10 deletions internal/extproc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/envoyproxy/ai-gateway/filterapi/x"
"github.com/envoyproxy/ai-gateway/internal/extproc/backendauth"
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

Expand Down Expand Up @@ -60,19 +59,11 @@ func (s *Server) LoadConfig(ctx context.Context, config *filterapi.Config) error
}

var (
factories = make(map[filterapi.VersionedAPISchema]translator.Factory)
backendAuthHandlers = make(map[string]backendauth.Handler)
declaredModels []string
)
for _, r := range config.Rules {
for _, b := range r.Backends {
if _, ok := factories[b.Schema]; !ok {
factories[b.Schema], err = translator.NewFactory(config.Schema, b.Schema)
if err != nil {
return fmt.Errorf("cannot create translator factory: %w", err)
}
}

if b.Auth != nil {
backendAuthHandlers[b.Name], err = backendauth.NewHandler(ctx, b.Auth)
if err != nil {
Expand Down Expand Up @@ -112,7 +103,6 @@ func (s *Server) LoadConfig(ctx context.Context, config *filterapi.Config) error
bodyParser: bodyParser, router: rt,
selectedBackendHeaderKey: config.SelectedBackendHeaderKey,
modelNameHeaderKey: config.ModelNameHeaderKey,
factories: factories,
backendAuthHandlers: backendAuthHandlers,
metadataNamespace: config.MetadataNamespace,
requestCosts: costs,
Expand Down
3 changes: 0 additions & 3 deletions internal/extproc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,6 @@ func TestServer_LoadConfig(t *testing.T) {
require.NotNil(t, s.config.bodyParser)
require.Equal(t, "x-ai-eg-selected-backend", s.config.selectedBackendHeaderKey)
require.Equal(t, "x-model-name", s.config.modelNameHeaderKey)
require.Len(t, s.config.factories, 2)
require.NotNil(t, s.config.factories[filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}])
require.NotNil(t, s.config.factories[filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock}])

require.Len(t, s.config.requestCosts, 2)
require.Equal(t, filterapi.LLMRequestCostTypeOutputToken, s.config.requestCosts[0].Type)
Expand Down
Loading