-
Notifications
You must be signed in to change notification settings - Fork 853
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3db25b0
commit e90a278
Showing
5 changed files
with
444 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
//go:build go1.18 | ||
// +build go1.18 | ||
|
||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See License.txt in the project root for license information. | ||
|
||
package azopenai | ||
|
||
import ( | ||
"context" | ||
"log" | ||
"os" | ||
"testing" | ||
|
||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" | ||
"github.com/google/go-cmp/cmp" | ||
"github.com/google/go-cmp/cmp/cmpopts" | ||
) | ||
|
||
var ( | ||
endpoint = os.Getenv("AOAI_ENDPOINT") | ||
apiKey = os.Getenv("AOAI_API_KEY") | ||
) | ||
|
||
func TestClient_GetChatCompletions(t *testing.T) { | ||
type args struct { | ||
ctx context.Context | ||
deploymentID string | ||
body ChatCompletionsOptions | ||
options *ClientGetChatCompletionsOptions | ||
} | ||
cred := KeyCredential{APIKey: apiKey} | ||
deploymentID := "gpt-35-turbo" | ||
chatClient, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) | ||
if err != nil { | ||
log.Fatalf("%v", err) | ||
} | ||
tests := []struct { | ||
name string | ||
client *Client | ||
args args | ||
want ClientGetChatCompletionsResponse | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "ChatCompletions", | ||
client: chatClient, | ||
args: args{ | ||
ctx: context.TODO(), | ||
deploymentID: "gpt-35-turbo", | ||
body: ChatCompletionsOptions{ | ||
Messages: []*ChatMessage{ | ||
{ | ||
Role: to.Ptr(ChatRole("user")), | ||
Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."), | ||
}, | ||
}, | ||
MaxTokens: to.Ptr(int32(1024)), | ||
Temperature: to.Ptr(float32(0.0)), | ||
}, | ||
options: nil, | ||
}, | ||
want: ClientGetChatCompletionsResponse{ | ||
ChatCompletions: ChatCompletions{ | ||
Choices: []*ChatChoice{ | ||
{ | ||
Message: &ChatChoiceMessage{ | ||
Role: to.Ptr(ChatRole("assistant")), | ||
Content: to.Ptr("1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100."), | ||
}, | ||
Index: to.Ptr(int32(0)), | ||
FinishReason: to.Ptr(CompletionsFinishReason("stop")), | ||
}, | ||
}, | ||
Usage: &CompletionsUsage{ | ||
CompletionTokens: to.Ptr(int32(299)), | ||
PromptTokens: to.Ptr(int32(37)), | ||
TotalTokens: to.Ptr(int32(336)), | ||
}, | ||
}, | ||
}, | ||
wantErr: false, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := tt.client.GetChatCompletions(tt.args.ctx, tt.args.body, tt.args.options) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("Client.GetChatCompletions() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
opts := cmpopts.IgnoreFields(ChatCompletions{}, "Created", "ID") | ||
if diff := cmp.Diff(tt.want.ChatCompletions, got.ChatCompletions, opts); diff != "" { | ||
t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestClient_GetCompletions(t *testing.T) { | ||
type args struct { | ||
ctx context.Context | ||
deploymentID string | ||
body CompletionsOptions | ||
options *ClientGetCompletionsOptions | ||
} | ||
cred := KeyCredential{APIKey: apiKey} | ||
deploymentID := "text-davinci-003" | ||
client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) | ||
if err != nil { | ||
log.Fatalf("%v", err) | ||
} | ||
tests := []struct { | ||
name string | ||
client *Client | ||
args args | ||
want ClientGetCompletionsResponse | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "chatbot", | ||
client: client, | ||
args: args{ | ||
ctx: context.TODO(), | ||
deploymentID: deploymentID, | ||
body: CompletionsOptions{ | ||
Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, | ||
MaxTokens: to.Ptr(int32(2048 - 127)), | ||
Temperature: to.Ptr(float32(0.0)), | ||
}, | ||
options: nil, | ||
}, | ||
want: ClientGetCompletionsResponse{ | ||
Completions: Completions{ | ||
Choices: []*Choice{ | ||
{ | ||
Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."), | ||
Index: to.Ptr(int32(0)), | ||
FinishReason: to.Ptr(CompletionsFinishReason("stop")), | ||
Logprobs: nil, | ||
}, | ||
}, | ||
Usage: &CompletionsUsage{ | ||
CompletionTokens: to.Ptr(int32(85)), | ||
PromptTokens: to.Ptr(int32(6)), | ||
TotalTokens: to.Ptr(int32(91)), | ||
}, | ||
}, | ||
}, | ||
wantErr: false, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := tt.client.GetCompletions(tt.args.ctx, tt.args.body, tt.args.options) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("Client.GetCompletions() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
opts := cmpopts.IgnoreFields(Completions{}, "Created", "ID") | ||
if diff := cmp.Diff(tt.want.Completions, got.Completions, opts); diff != "" { | ||
t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestClient_GetEmbeddings(t *testing.T) { | ||
type args struct { | ||
ctx context.Context | ||
deploymentID string | ||
body EmbeddingsOptions | ||
options *ClientGetEmbeddingsOptions | ||
} | ||
deploymentID := "embedding" | ||
cred := KeyCredential{APIKey: apiKey} | ||
client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) | ||
if err != nil { | ||
log.Fatalf("%v", err) | ||
} | ||
tests := []struct { | ||
name string | ||
client *Client | ||
args args | ||
want ClientGetEmbeddingsResponse | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "Embeddings", | ||
client: client, | ||
args: args{ | ||
ctx: context.TODO(), | ||
deploymentID: "embedding", | ||
body: EmbeddingsOptions{ | ||
Input: "Your text string goes here", | ||
Model: to.Ptr("text-similarity-curie-001"), | ||
}, | ||
options: nil, | ||
}, | ||
want: ClientGetEmbeddingsResponse{ | ||
Embeddings{ | ||
Data: []*EmbeddingItem{}, | ||
Usage: &EmbeddingsUsage{}, | ||
}, | ||
}, | ||
wantErr: false, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if len(got.Embeddings.Data[0].Embedding) != 4096 { | ||
t.Errorf("Client.GetEmbeddings() len(Data) want 4096, got %d", len(got.Embeddings.Data)) | ||
return | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
//go:build go1.18 | ||
// +build go1.18 | ||
|
||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See License.txt in the project root for license information. | ||
|
||
package azopenai | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"reflect" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/Azure/azure-sdk-for-go/sdk/azcore" | ||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" | ||
) | ||
|
||
func TestNewClient(t *testing.T) { | ||
type args struct { | ||
endpoint string | ||
credential azcore.TokenCredential | ||
deploymentID string | ||
options *ClientOptions | ||
} | ||
tests := []struct { | ||
name string | ||
args args | ||
want *Client | ||
wantErr bool | ||
}{ | ||
// TODO: Add test cases. | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := NewClient(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if !reflect.DeepEqual(got, tt.want) { | ||
t.Errorf("NewClient() = %v, want %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestNewClientWithKeyCredential(t *testing.T) { | ||
type args struct { | ||
endpoint string | ||
credential KeyCredential | ||
deploymentID string | ||
options *ClientOptions | ||
} | ||
tests := []struct { | ||
name string | ||
args args | ||
want *Client | ||
wantErr bool | ||
}{ | ||
// TODO: Add test cases. | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := NewClientWithKeyCredential(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("NewClientWithKeyCredential() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if !reflect.DeepEqual(got, tt.want) { | ||
t.Errorf("NewClientWithKeyCredential() = %v, want %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestClient_GetCompletionsStream(t *testing.T) { | ||
body := CompletionsOptions{ | ||
Prompt: []*string{to.Ptr("What is Azure OpenAI?")}, | ||
MaxTokens: to.Ptr(int32(2048 - 127)), | ||
Temperature: to.Ptr(float32(0.0)), | ||
} | ||
cred := KeyCredential{APIKey: apiKey} | ||
deploymentID := "text-davinci-003" | ||
client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil) | ||
if err != nil { | ||
t.Errorf("NewClientWithKeyCredential() error = %v", err) | ||
return | ||
} | ||
response, err := client.GetCompletionsStream(context.TODO(), body, nil) | ||
if err != nil { | ||
t.Errorf("Client.GetCompletionsStream() error = %v", err) | ||
return | ||
} | ||
reader := response.Events | ||
defer reader.Close() | ||
|
||
var sb strings.Builder | ||
var eventCount int | ||
for { | ||
event, err := reader.Read() | ||
if err == io.EOF { | ||
break | ||
} | ||
eventCount++ | ||
if err != nil { | ||
t.Errorf("reader.Read() error = %v", err) | ||
return | ||
} | ||
sb.WriteString(*event.Choices[0].Text) | ||
} | ||
got := sb.String() | ||
const want = "\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models." | ||
if got != want { | ||
i := 0 | ||
for i < len(got) && i < len(want) && got[i] == want[i] { | ||
i++ | ||
} | ||
t.Errorf("Client.GetCompletionsStream() text[%d] = %c, want %c", i, got[i], want[i]) | ||
} | ||
if eventCount != 86 { | ||
t.Errorf("Client.GetCompletionsStream() got = %v, want %v", eventCount, 1) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,14 @@ | ||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 h1:rTnT/Jrcm+figWlYz4Ixzt0SJVR2cMC8lvZcimipiEY= | ||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= | ||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= | ||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= | ||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk= | ||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= | ||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= | ||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= | ||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= | ||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= | ||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= | ||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= | ||
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= | ||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= | ||
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= | ||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= | ||
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= | ||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= | ||
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= | ||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | ||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= |
Oops, something went wrong.