feat(llm-bridge): add vllm provider (#1005)

1 parent 825dd41, commit 3d9c03d
Showing 4 changed files with 147 additions and 0 deletions.
@@ -0,0 +1,71 @@
// Package vllm is the vllm llm provider
package vllm

import (
    "context"

    _ "github.com/joho/godotenv/autoload"
    "github.com/sashabaranov/go-openai"
    "github.com/yomorun/yomo/core/metadata"

    provider "github.com/yomorun/yomo/pkg/bridge/ai/provider"
)

// check that Provider implements the provider.LLMProvider interface
var _ provider.LLMProvider = &Provider{}

// Provider is the provider for vllm
type Provider struct {
    // APIEndpoint is the vllm OpenAI-compatible API endpoint
    APIEndpoint string
    // APIKey is the API key for vllm
    APIKey string
    // Model is the model for vllm
    // e.g. "meta-llama/Llama-3.2-7B-Instruct"
    Model string

    client *openai.Client
}

// NewProvider creates a new vllm ai provider
func NewProvider(apiEndpoint string, apiKey string, model string) *Provider {
    if apiEndpoint == "" {
        apiEndpoint = "http://127.0.0.1:8000"
    }
    if model == "" {
        model = "meta-llama/Llama-3.2-7B-Instruct"
    }

    c := openai.DefaultConfig(apiKey)
    // the vllm endpoint differs from the default openai api endpoint, so we need to append "/v1" to the endpoint
    c.BaseURL = apiEndpoint + "/v1"

    return &Provider{
        APIEndpoint: apiEndpoint,
        APIKey:      apiKey,
        Model:       model,
        client:      openai.NewClientWithConfig(c),
    }
}

// Name returns the name of the provider
func (p *Provider) Name() string {
    return "vllm"
}

// GetChatCompletions implements ai.LLMProvider.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
    if req.Model == "" {
        req.Model = p.Model
    }

    return p.client.CreateChatCompletion(ctx, req)
}

// GetChatCompletionsStream implements ai.LLMProvider.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
    if req.Model == "" {
        req.Model = p.Model
    }

    return p.client.CreateChatCompletionStream(ctx, req)
}
@@ -0,0 +1,67 @@
package vllm

import (
    "context"
    "testing"

    "github.com/sashabaranov/go-openai"
    "github.com/stretchr/testify/assert"
)

func TestVLlmProvider_Name(t *testing.T) {
    provider := &Provider{}
    name := provider.Name()

    assert.Equal(t, "vllm", name)
}

func TestVLlmProvider_GetChatCompletions(t *testing.T) {
    provider := NewProvider("", "", "")
    msgs := []openai.ChatCompletionMessage{
        {
            Role:    "user",
            Content: "hello",
        },
        {
            Role:    "system",
            Content: "I'm a bot",
        },
    }
    req := openai.ChatCompletionRequest{
        Messages: msgs,
        Model:    "meta-llama/Llama-3.2-7B-Instruct",
    }

    _, err := provider.GetChatCompletions(context.TODO(), req, nil)
    assert.Error(t, err)
    t.Log(err)

    _, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
    assert.Error(t, err)
    t.Log(err)
}

func TestVLlmProvider_GetChatCompletionsWithoutModel(t *testing.T) {
    provider := NewProvider("", "", "")
    msgs := []openai.ChatCompletionMessage{
        {
            Role:    "user",
            Content: "hello",
        },
        {
            Role:    "system",
            Content: "I'm a bot",
        },
    }
    req := openai.ChatCompletionRequest{
        Messages: msgs,
    }

    _, err := provider.GetChatCompletions(context.TODO(), req, nil)
    assert.Error(t, err)
    t.Log(err)

    _, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
    assert.Error(t, err)
    t.Log(err)
}
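These tests only assert that an error comes back when no vLLM server is reachable at the default endpoint, so they run offline. Assuming the standard Go module layout for this package, an invocation along these lines should work:

go test -v -run TestVLlmProvider ./...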