Skip to content

Commit

Permalink
[azopenai] Fixing example's typo, creating a constructor for KeyCrede…
Browse files Browse the repository at this point in the history
…ntial (#21042)

Fixing a few things based on feedback from @JeffreyRichter.
- Example had a typo and a hardcoded calculation.
- Adding a constructor for KeyCredential.
  • Loading branch information
richardpark-msft authored Jun 21, 2023
1 parent 9c53d0f commit 1e4fb7e
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 30 deletions.
8 changes: 3 additions & 5 deletions sdk/cognitiveservices/azopenai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,9 @@ the [Code of Conduct FAQ][coc_faq] or contact [[email protected]][coc_conta
comments.

<!-- LINKS -->
<!-- https://github.com/Azure/azure-sdk-for-go/tree/main/sdk/cognitiveservices/azopenai ->
<!-- https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai -->
[azure_openai_access]: https://learn.microsoft.com/azure/cognitive-services/openai/overview#how-do-i-get-access-to-azure-openai
[azopenai_repo]: https://github.com/Azure/azure-sdk-for-go/tree/main/sdk
[azopenai_pkg_go]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk
[azopenai_repo]: https://github.com/Azure/azure-sdk-for-go/tree/main/sdk/cognitiveservices/azopenai
[azopenai_pkg_go]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai
[azure_identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity
[azure_sub]: https://azure.microsoft.com/free/
[openai_docs]: https://learn.microsoft.com/azure/cognitive-services/openai
Expand All @@ -104,4 +102,4 @@ comments.
[cla]: https://cla.microsoft.com
[coc]: https://opensource.microsoft.com/codeofconduct/
[coc_faq]: https://opensource.microsoft.com/codeofconduct/faq/
[coc_contact]: mailto:[email protected]
[coc_contact]: mailto:[email protected]
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/cognitiveservices/azopenai",
"Tag": "go/cognitiveservices/azopenai_0bc6dc4171"
"Tag": "go/cognitiveservices/azopenai_49bcacb061"
}
25 changes: 19 additions & 6 deletions sdk/cognitiveservices/azopenai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
func TestClient_GetChatCompletions(t *testing.T) {
deploymentID := "gpt-35-turbo"

cred := KeyCredential{APIKey: apiKey}
cred, err := NewKeyCredential(apiKey)
require.NoError(t, err)

chatClient, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t))
require.NoError(t, err)

Expand Down Expand Up @@ -158,7 +160,9 @@ func testGetChatCompletions(t *testing.T, chatClient *Client, modelOrDeployment
}

func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
cred := KeyCredential{APIKey: apiKey}
cred, err := NewKeyCredential(apiKey)
require.NoError(t, err)

chatClient, err := NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t))
require.NoError(t, err)

Expand All @@ -179,7 +183,9 @@ func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
}

func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
cred := KeyCredential{APIKey: apiKey}
cred, err := NewKeyCredential(apiKey)
require.NoError(t, err)

chatClient, err := NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t))
require.NoError(t, err)

Expand All @@ -197,7 +203,9 @@ func TestClient_GetCompletions(t *testing.T) {
body CompletionsOptions
options *GetCompletionsOptions
}
cred := KeyCredential{APIKey: apiKey}
cred, err := NewKeyCredential(apiKey)
require.NoError(t, err)

client, err := NewClientWithKeyCredential(endpoint, cred, streamingModelDeployment, newClientOptionsForTest(t))
if err != nil {
log.Fatalf("%v", err)
Expand Down Expand Up @@ -267,7 +275,9 @@ func TestClient_GetEmbeddings(t *testing.T) {
// model deployment points to `text-similarity-curie-001`
deploymentID := "embedding"

cred := KeyCredential{APIKey: apiKey}
cred, err := NewKeyCredential(apiKey)
require.NoError(t, err)

client, err := NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t))
require.NoError(t, err)

Expand Down Expand Up @@ -330,7 +340,10 @@ func newOpenAIClientForTest(t *testing.T) *Client {
t.Skipf("OPENAI_API_KEY not defined, skipping OpenAI public endpoint test")
}

chatClient, err := NewClientForOpenAI(openAIEndpoint, KeyCredential{APIKey: openAIKey}, newClientOptionsForTest(t))
cred, err := NewKeyCredential(openAIKey)
require.NoError(t, err)

chatClient, err := NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

return chatClient
Expand Down
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/custom_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func newOpenAIPolicy(cred KeyCredential) *openAIPolicy {
func (b *openAIPolicy) Do(req *policy.Request) (*http.Response, error) {
q := req.Raw().URL.Query()
q.Del("api-version")
req.Raw().Header.Set("authorization", "Bearer "+b.cred.APIKey)
req.Raw().Header.Set("authorization", "Bearer "+b.cred.apiKey)
return req.Next()
}

Expand Down
7 changes: 5 additions & 2 deletions sdk/cognitiveservices/azopenai/custom_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/stretchr/testify/require"
)

func TestNewClient(t *testing.T) {
Expand Down Expand Up @@ -78,10 +79,12 @@ func TestNewClientWithKeyCredential(t *testing.T) {
func TestClient_GetCompletionsStream(t *testing.T) {
body := CompletionsOptions{
Prompt: []*string{to.Ptr("What is Azure OpenAI?")},
MaxTokens: to.Ptr(int32(2048 - 127)),
MaxTokens: to.Ptr(int32(2048)),
Temperature: to.Ptr(float32(0.0)),
}
cred := KeyCredential{APIKey: apiKey}

cred, err := NewKeyCredential(apiKey)
require.NoError(t, err)

client, err := NewClientWithKeyCredential(endpoint, cred, streamingModelDeployment, newClientOptionsForTest(t))
if err != nil {
Expand Down
22 changes: 14 additions & 8 deletions sdk/cognitiveservices/azopenai/examples_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import (
func ExampleNewClientForOpenAI() {
// NOTE: this constructor creates a client that connects to the public OpenAI endpoint.
// To connect to an Azure OpenAI endpoint, use azopenai.NewClient() or azopenai.NewClientWithyKeyCredential.
keyCredential := azopenai.KeyCredential{
APIKey: "open-ai-apikey",
keyCredential, err := azopenai.NewKeyCredential("<OpenAI-APIKey>")

if err != nil {
panic(err)
}

client, err := azopenai.NewClientForOpenAI("https://api.openai.com/v1", keyCredential, nil)
Expand Down Expand Up @@ -53,8 +55,10 @@ func ExampleNewClient() {
func ExampleNewClientWithKeyCredential() {
// NOTE: this constructor creates a client that connects to an Azure OpenAI endpoint.
// To connect to the public OpenAI endpoint, use azopenai.NewClientForOpenAI
keyCredential := azopenai.KeyCredential{
APIKey: "Azure OpenAI apikey",
keyCredential, err := azopenai.NewKeyCredential("<Azure-OpenAI-APIKey>")

if err != nil {
panic(err)
}

modelDeploymentID := "model deployment ID"
Expand All @@ -78,8 +82,10 @@ func ExampleClient_GetCompletionsStream() {
return
}

keyCredential := azopenai.KeyCredential{
APIKey: azureOpenAIKey,
keyCredential, err := azopenai.NewKeyCredential(azureOpenAIKey)

if err != nil {
panic(err)
}

client, err := azopenai.NewClientWithKeyCredential(azureOpenAIEndpoint, keyCredential, modelDeploymentID, nil)
Expand All @@ -90,7 +96,7 @@ func ExampleClient_GetCompletionsStream() {

resp, err := client.GetCompletionsStream(context.TODO(), azopenai.CompletionsOptions{
Prompt: []*string{to.Ptr("What is Azure OpenAI?")},
MaxTokens: to.Ptr(int32(2048 - 127)),
MaxTokens: to.Ptr(int32(2048)),
Temperature: to.Ptr(float32(0.0)),
}, nil)

Expand All @@ -102,7 +108,7 @@ func ExampleClient_GetCompletionsStream() {
entry, err := resp.CompletionsStream.Read()

if errors.Is(err, io.EOF) {
fmt.Printf("More more completions")
fmt.Printf("\n *** No more completions ***\n")
break
}

Expand Down
12 changes: 9 additions & 3 deletions sdk/cognitiveservices/azopenai/policy_apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@ import (

// KeyCredential is used when doing APIKey-based authentication.
type KeyCredential struct {
// APIKey is the api key for the client.
APIKey string
// apiKey is the api key for the client.
apiKey string
}

// NewKeyCredential creates a KeyCredential containing an API key for
// either Azure OpenAI or OpenAI.
func NewKeyCredential(apiKey string) (KeyCredential, error) {
return KeyCredential{apiKey: apiKey}, nil
}

// apiKeyPolicy authorizes requests with an API key acquired from a KeyCredential.
Expand All @@ -35,6 +41,6 @@ func newAPIKeyPolicy(cred KeyCredential, header string) *apiKeyPolicy {

// Do returns a function which authorizes req with a token from the policy's credential
func (b *apiKeyPolicy) Do(req *policy.Request) (*http.Response, error) {
req.Raw().Header.Set(b.header, b.cred.APIKey)
req.Raw().Header.Set(b.header, b.cred.apiKey)
return req.Next()
}
12 changes: 8 additions & 4 deletions sdk/cognitiveservices/azopenai/policy_apikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/stretchr/testify/require"

"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)
Expand All @@ -23,7 +24,9 @@ func TestNewAPIKeyPolicy(t *testing.T) {
header string
cred KeyCredential
}
simpleCred := KeyCredential{APIKey: "apiKey"}
simpleCred, err := NewKeyCredential("apiKey")
require.NoError(t, err)

simpleHeader := "headerName"
tests := []struct {
name string
Expand Down Expand Up @@ -55,9 +58,10 @@ func TestAPIKeyPolicy_Success(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithStatusCode(http.StatusOK))
cred := KeyCredential{
APIKey: "secret",
}

cred, err := NewKeyCredential("secret")
require.NoError(t, err)

authPolicy := newAPIKeyPolicy(cred, "api-key")
pipeline := runtime.NewPipeline(
"testmodule",
Expand Down

0 comments on commit 1e4fb7e

Please sign in to comment.