Skip to content

Commit

Permalink
Add tests -- WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
mikekistler committed Jun 6, 2023
1 parent 3db25b0 commit e90a278
Show file tree
Hide file tree
Showing 5 changed files with 444 additions and 12 deletions.
222 changes: 222 additions & 0 deletions sdk/cognitiveservices/azopenai/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

package azopenai

import (
"context"
"log"
"os"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

var (
endpoint = os.Getenv("AOAI_ENDPOINT")
apiKey = os.Getenv("AOAI_API_KEY")
)

func TestClient_GetChatCompletions(t *testing.T) {
type args struct {
ctx context.Context
deploymentID string
body ChatCompletionsOptions
options *ClientGetChatCompletionsOptions
}
cred := KeyCredential{APIKey: apiKey}
deploymentID := "gpt-35-turbo"
chatClient, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil)
if err != nil {
log.Fatalf("%v", err)
}
tests := []struct {
name string
client *Client
args args
want ClientGetChatCompletionsResponse
wantErr bool
}{
{
name: "ChatCompletions",
client: chatClient,
args: args{
ctx: context.TODO(),
deploymentID: "gpt-35-turbo",
body: ChatCompletionsOptions{
Messages: []*ChatMessage{
{
Role: to.Ptr(ChatRole("user")),
Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."),
},
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
},
options: nil,
},
want: ClientGetChatCompletionsResponse{
ChatCompletions: ChatCompletions{
Choices: []*ChatChoice{
{
Message: &ChatChoiceMessage{
Role: to.Ptr(ChatRole("assistant")),
Content: to.Ptr("1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100."),
},
Index: to.Ptr(int32(0)),
FinishReason: to.Ptr(CompletionsFinishReason("stop")),
},
},
Usage: &CompletionsUsage{
CompletionTokens: to.Ptr(int32(299)),
PromptTokens: to.Ptr(int32(37)),
TotalTokens: to.Ptr(int32(336)),
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.client.GetChatCompletions(tt.args.ctx, tt.args.body, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("Client.GetChatCompletions() error = %v, wantErr %v", err, tt.wantErr)
return
}
opts := cmpopts.IgnoreFields(ChatCompletions{}, "Created", "ID")
if diff := cmp.Diff(tt.want.ChatCompletions, got.ChatCompletions, opts); diff != "" {
t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff)
}
})
}
}

func TestClient_GetCompletions(t *testing.T) {
type args struct {
ctx context.Context
deploymentID string
body CompletionsOptions
options *ClientGetCompletionsOptions
}
cred := KeyCredential{APIKey: apiKey}
deploymentID := "text-davinci-003"
client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil)
if err != nil {
log.Fatalf("%v", err)
}
tests := []struct {
name string
client *Client
args args
want ClientGetCompletionsResponse
wantErr bool
}{
{
name: "chatbot",
client: client,
args: args{
ctx: context.TODO(),
deploymentID: deploymentID,
body: CompletionsOptions{
Prompt: []*string{to.Ptr("What is Azure OpenAI?")},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
},
options: nil,
},
want: ClientGetCompletionsResponse{
Completions: Completions{
Choices: []*Choice{
{
Text: to.Ptr("\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."),
Index: to.Ptr(int32(0)),
FinishReason: to.Ptr(CompletionsFinishReason("stop")),
Logprobs: nil,
},
},
Usage: &CompletionsUsage{
CompletionTokens: to.Ptr(int32(85)),
PromptTokens: to.Ptr(int32(6)),
TotalTokens: to.Ptr(int32(91)),
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.client.GetCompletions(tt.args.ctx, tt.args.body, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("Client.GetCompletions() error = %v, wantErr %v", err, tt.wantErr)
return
}
opts := cmpopts.IgnoreFields(Completions{}, "Created", "ID")
if diff := cmp.Diff(tt.want.Completions, got.Completions, opts); diff != "" {
t.Errorf("Client.GetCompletions(): -want, +got:\n%s", diff)
}
})
}
}

func TestClient_GetEmbeddings(t *testing.T) {
type args struct {
ctx context.Context
deploymentID string
body EmbeddingsOptions
options *ClientGetEmbeddingsOptions
}
deploymentID := "embedding"
cred := KeyCredential{APIKey: apiKey}
client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil)
if err != nil {
log.Fatalf("%v", err)
}
tests := []struct {
name string
client *Client
args args
want ClientGetEmbeddingsResponse
wantErr bool
}{
{
name: "Embeddings",
client: client,
args: args{
ctx: context.TODO(),
deploymentID: "embedding",
body: EmbeddingsOptions{
Input: "Your text string goes here",
Model: to.Ptr("text-similarity-curie-001"),
},
options: nil,
},
want: ClientGetEmbeddingsResponse{
Embeddings{
Data: []*EmbeddingItem{},
Usage: &EmbeddingsUsage{},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.client.GetEmbeddings(tt.args.ctx, tt.args.body, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("Client.GetEmbeddings() error = %v, wantErr %v", err, tt.wantErr)
return
}
if len(got.Embeddings.Data[0].Embedding) != 4096 {
t.Errorf("Client.GetEmbeddings() len(Data) want 4096, got %d", len(got.Embeddings.Data))
return
}
})
}
}
125 changes: 125 additions & 0 deletions sdk/cognitiveservices/azopenai/custom_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

package azopenai

import (
"context"
"io"
"reflect"
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)

func TestNewClient(t *testing.T) {
type args struct {
endpoint string
credential azcore.TokenCredential
deploymentID string
options *ClientOptions
}
tests := []struct {
name string
args args
want *Client
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewClient(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewClient() = %v, want %v", got, tt.want)
}
})
}
}

func TestNewClientWithKeyCredential(t *testing.T) {
type args struct {
endpoint string
credential KeyCredential
deploymentID string
options *ClientOptions
}
tests := []struct {
name string
args args
want *Client
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewClientWithKeyCredential(tt.args.endpoint, tt.args.credential, tt.args.deploymentID, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("NewClientWithKeyCredential() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewClientWithKeyCredential() = %v, want %v", got, tt.want)
}
})
}
}

func TestClient_GetCompletionsStream(t *testing.T) {
body := CompletionsOptions{
Prompt: []*string{to.Ptr("What is Azure OpenAI?")},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
}
cred := KeyCredential{APIKey: apiKey}
deploymentID := "text-davinci-003"
client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, nil)
if err != nil {
t.Errorf("NewClientWithKeyCredential() error = %v", err)
return
}
response, err := client.GetCompletionsStream(context.TODO(), body, nil)
if err != nil {
t.Errorf("Client.GetCompletionsStream() error = %v", err)
return
}
reader := response.Events
defer reader.Close()

var sb strings.Builder
var eventCount int
for {
event, err := reader.Read()
if err == io.EOF {
break
}
eventCount++
if err != nil {
t.Errorf("reader.Read() error = %v", err)
return
}
sb.WriteString(*event.Choices[0].Text)
}
got := sb.String()
const want = "\n\nAzure OpenAI is a platform from Microsoft that provides access to OpenAI's artificial intelligence (AI) technologies. It enables developers to build, train, and deploy AI models in the cloud. Azure OpenAI provides access to OpenAI's powerful AI technologies, such as GPT-3, which can be used to create natural language processing (NLP) applications, computer vision models, and reinforcement learning models."
if got != want {
i := 0
for i < len(got) && i < len(want) && got[i] == want[i] {
i++
}
t.Errorf("Client.GetCompletionsStream() text[%d] = %c, want %c", i, got[i], want[i])
}
if eventCount != 86 {
t.Errorf("Client.GetCompletionsStream() got = %v, want %v", eventCount, 1)
}
}
11 changes: 7 additions & 4 deletions sdk/cognitiveservices/azopenai/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ module github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai

go 1.18

require github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0
github.com/google/go-cmp v0.5.9
)

require (
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/text v0.7.0 // indirect
golang.org/x/net v0.8.0 // indirect
golang.org/x/text v0.8.0 // indirect
)
18 changes: 10 additions & 8 deletions sdk/cognitiveservices/azopenai/go.sum
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 h1:rTnT/Jrcm+figWlYz4Ixzt0SJVR2cMC8lvZcimipiEY=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
Loading

0 comments on commit e90a278

Please sign in to comment.