feat(llm-bridge): add vllm provider (#1005)

fanweixiao authored Feb 8, 2025
1 parent 825dd41 commit 3d9c03d
Showing 4 changed files with 147 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cli/serve.go
@@ -41,6 +41,7 @@ import (
"github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/vertexai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/vllm"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/xai"
)

@@ -185,6 +186,8 @@ func registerAIProvider(aiConfig *ai.Config) error {
		))
	case "deepseek":
		providerpkg.RegisterProvider(cerebras.NewProvider(provider["api_key"], provider["model"]))
	case "vllm":
		providerpkg.RegisterProvider(vllm.NewProvider(provider["api_endpoint"], provider["api_key"], provider["model"]))
	default:
		log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
	}
6 changes: 6 additions & 0 deletions example/10-ai/zipper.yaml
@@ -54,3 +54,9 @@ bridge:
    location:
    model:
    credentials_file:

  vllm:
    api_endpoint: http://127.0.0.1:8000/v1
    api_key:
    model: deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
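
Since vLLM exposes an OpenAI-compatible API under the /v1 prefix, the endpoint configured above can be smoke-tested with the go-openai client directly, before any traffic is routed through the bridge. A minimal sketch (not part of this commit), assuming a local vLLM server is already serving the model named in the config:

package main

import (
	"context"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func main() {
	// Point the stock go-openai client at the local vLLM server.
	cfg := openai.DefaultConfig("") // no API key needed for a local server
	cfg.BaseURL = "http://127.0.0.1:8000/v1"
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "hello"},
		},
	})
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}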

71 changes: 71 additions & 0 deletions pkg/bridge/ai/provider/vllm/provider.go
@@ -0,0 +1,71 @@
// Package vllm is the vLLM LLM provider.
package vllm

import (
	"context"

	_ "github.com/joho/godotenv/autoload"
	"github.com/sashabaranov/go-openai"
	"github.com/yomorun/yomo/core/metadata"

	provider "github.com/yomorun/yomo/pkg/bridge/ai/provider"
)

// compile-time check that Provider implements provider.LLMProvider
var _ provider.LLMProvider = &Provider{}

// Provider is the provider for vLLM.
type Provider struct {
	// APIEndpoint is the vLLM OpenAI-compatible API endpoint.
	APIEndpoint string
	// APIKey is the API key for vLLM.
	APIKey string
	// Model is the model served by vLLM,
	// e.g. "meta-llama/Llama-3.2-7B-Instruct".
	Model string
	// client is the underlying go-openai client.
	client *openai.Client
}

// NewProvider creates a new vLLM AI provider.
func NewProvider(apiEndpoint string, apiKey string, model string) *Provider {
	if apiEndpoint == "" {
		// vLLM serves its OpenAI-compatible API under the "/v1" prefix,
		// unlike the default OpenAI endpoint, so the default includes it.
		apiEndpoint = "http://127.0.0.1:8000/v1"
	}
	if model == "" {
		model = "meta-llama/Llama-3.2-7B-Instruct"
	}

	c := openai.DefaultConfig(apiKey)
	c.BaseURL = apiEndpoint

	return &Provider{
		APIEndpoint: apiEndpoint,
		APIKey:      apiKey,
		Model:       model,
		client:      openai.NewClientWithConfig(c),
	}
}
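
For context (not part of the diff), a minimal usage sketch of the constructor, reusing the imports above; the endpoint and model are the documented defaults, and empty arguments fall back to them:

	// Construct a provider against a local vLLM server.
	p := vllm.NewProvider("http://127.0.0.1:8000/v1", "", "meta-llama/Llama-3.2-7B-Instruct")

	// req.Model is empty here, so the call below falls back to p.Model.
	resp, err := p.GetChatCompletions(context.Background(), openai.ChatCompletionRequest{
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "hello"},
		},
	}, nil)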

// Name returns the name of the provider.
func (p *Provider) Name() string {
	return "vllm"
}

// GetChatCompletions implements provider.LLMProvider.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
	if req.Model == "" {
		req.Model = p.Model
	}

	return p.client.CreateChatCompletion(ctx, req)
}

// GetChatCompletionsStream implements provider.LLMProvider.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
	if req.Model == "" {
		req.Model = p.Model
	}

	return p.client.CreateChatCompletionStream(ctx, req)
}
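
The streaming variant returns a provider.ResponseRecver, which go-openai's *ChatCompletionStream satisfies. Assuming that interface exposes the same Recv() method as the underlying stream, a caller can drain the deltas like this (a sketch, not part of this commit; requires the errors, fmt, io, and log imports):

	stream, err := p.GetChatCompletionsStream(context.Background(), req, nil)
	if err != nil {
		log.Fatal(err)
	}
	for {
		chunk, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break // server closed the stream
		}
		if err != nil {
			log.Fatal(err)
		}
		fmt.Print(chunk.Choices[0].Delta.Content)
	}
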
67 changes: 67 additions & 0 deletions pkg/bridge/ai/provider/vllm/provider_test.go
@@ -0,0 +1,67 @@
package vllm

import (
	"context"
	"testing"

	"github.com/sashabaranov/go-openai"
	"github.com/stretchr/testify/assert"
)

func TestVLlmProvider_Name(t *testing.T) {
	provider := &Provider{}
	name := provider.Name()

	assert.Equal(t, "vllm", name)
}

func TestVLlmProvider_GetChatCompletions(t *testing.T) {
	provider := NewProvider("", "", "")
	msgs := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "hello",
		},
		{
			Role:    "system",
			Content: "I'm a bot",
		},
	}
	req := openai.ChatCompletionRequest{
		Messages: msgs,
		Model:    "meta-llama/Llama-3.2-7B-Instruct",
	}

	// No vLLM server is listening during tests, so both calls should fail.
	_, err := provider.GetChatCompletions(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)

	_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)
}

func TestVLlmProvider_GetChatCompletionsWithoutModel(t *testing.T) {
	provider := NewProvider("", "", "")
	msgs := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "hello",
		},
		{
			Role:    "system",
			Content: "I'm a bot",
		},
	}
	// Model is left empty here so the provider's default model is used.
	req := openai.ChatCompletionRequest{
		Messages: msgs,
	}

	_, err := provider.GetChatCompletions(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)

	_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)
}
