Skip to content

Commit

Permalink
[azopenai] Re-enable the tests for vision and expand it to include Az…
Browse files Browse the repository at this point in the history
…ure. (#22130)

- Do proper test recordings for audio so we can enable those.
- Bring back vision, and also target Azure.

Overall, this brings our test coverage back up to 32% from 24%.

Fixes #21598
  • Loading branch information
richardpark-msft authored Dec 13, 2023
1 parent 58ac9ec commit ed5bc4b
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 46 deletions.
2 changes: 1 addition & 1 deletion eng/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
},
{
"Name": "azopenai",
"CoverageGoal": 0.24
"CoverageGoal": 0.32
},
{
"Name": "aztemplate",
Expand Down
2 changes: 1 addition & 1 deletion sdk/ai/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/ai/azopenai",
"Tag": "go/ai/azopenai_d4fd4783ec"
"Tag": "go/ai/azopenai_b42da78821"
}
17 changes: 0 additions & 17 deletions sdk/ai/azopenai/client_audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,15 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/stretchr/testify/require"
)

func TestClient_GetAudioTranscription_AzureOpenAI(t *testing.T) {
if recording.GetRecordMode() != recording.LiveMode {
t.Skipf("Recording needs to be revisited for multipart: https://github.com/Azure/azure-sdk-for-go/issues/21598")
}

client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
runTranscriptionTests(t, client, azureOpenAI.Whisper.Model)
}

func TestClient_GetAudioTranscription_OpenAI(t *testing.T) {
if recording.GetRecordMode() != recording.LiveMode {
t.Skipf("Recording needs to be revisited for multipart: https://github.com/Azure/azure-sdk-for-go/issues/21598")
}

client := newOpenAIClientForTest(t)

mp3Bytes, err := os.ReadFile(`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3`)
Expand All @@ -50,19 +41,11 @@ func TestClient_GetAudioTranscription_OpenAI(t *testing.T) {
}

func TestClient_GetAudioTranslation_AzureOpenAI(t *testing.T) {
if recording.GetRecordMode() != recording.LiveMode {
t.Skipf("Recording needs to be revisited for multipart: https://github.com/Azure/azure-sdk-for-go/issues/21598")
}

client := newTestClient(t, azureOpenAI.Whisper.Endpoint, withForgivingRetryOption())
runTranslationTests(t, client, azureOpenAI.Whisper.Model)
}

func TestClient_GetAudioTranslation_OpenAI(t *testing.T) {
if recording.GetRecordMode() != recording.LiveMode {
t.Skipf("Recording needs to be revisited for multipart: https://github.com/Azure/azure-sdk-for-go/issues/21598")
}

client := newOpenAIClientForTest(t)

mp3Bytes, err := os.ReadFile(`testdata/sampledata_audiofiles_myVoiceIsMyPassportVerifyMe01.mp3`)
Expand Down
59 changes: 32 additions & 27 deletions sdk/ai/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,38 +230,43 @@ func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
}

func TestClient_OpenAI_GetChatCompletions_Vision(t *testing.T) {
if recording.GetRecordMode() == recording.LiveMode {
// we're having an issue right now with this preview feature. I luckily got
// a recording of it but it won't run in live mode at this moment.
t.Skipf("Skipping %s because of a temp live outage with vision preview", t.Name())
}

imageURL := "https://www.bing.com/th?id=OHR.BradgateFallow_EN-US3932725763_1920x1080.jpg"
testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string) {
imageURL := "https://www.bing.com/th?id=OHR.BradgateFallow_EN-US3932725763_1920x1080.jpg"

chatClient := newOpenAIClientForTest(t)
content := azopenai.NewChatRequestUserMessageContent([]azopenai.ChatCompletionRequestMessageContentPartClassification{
&azopenai.ChatCompletionRequestMessageContentPartText{
Text: to.Ptr("Describe this image"),
},
&azopenai.ChatCompletionRequestMessageContentPartImage{
ImageURL: &azopenai.ChatCompletionRequestMessageContentPartImageURL{
URL: &imageURL,
content := azopenai.NewChatRequestUserMessageContent([]azopenai.ChatCompletionRequestMessageContentPartClassification{
&azopenai.ChatCompletionRequestMessageContentPartText{
Text: to.Ptr("Describe this image"),
},
},
})
&azopenai.ChatCompletionRequestMessageContentPartImage{
ImageURL: &azopenai.ChatCompletionRequestMessageContentPartImageURL{
URL: &imageURL,
},
},
})

resp, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestUserMessage{
Content: content,
resp, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatRequestMessageClassification{
&azopenai.ChatRequestUserMessage{
Content: content,
},
},
},
DeploymentName: to.Ptr("gpt-4-vision-preview"),
}, nil)
require.NoError(t, err)
require.NotEmpty(t, resp.Choices[0].Message.Content)
DeploymentName: to.Ptr(deploymentName),
}, nil)
require.NoError(t, err)
require.NotEmpty(t, resp.Choices[0].Message.Content)

t.Logf(*resp.Choices[0].Message.Content)
}

t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testFn(t, chatClient, openAI.Vision.Model)
})

t.Logf(*resp.Choices[0].Message.Content)
t.Run("AOAI", func(t *testing.T) {
chatClient := newTestClient(t, azureOpenAI.Vision.Endpoint)
testFn(t, chatClient, azureOpenAI.Vision.Model)
})
}

func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
Expand Down
69 changes: 69 additions & 0 deletions sdk/ai/azopenai/client_shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
package azopenai_test

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"regexp"
Expand All @@ -16,6 +19,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/joho/godotenv"
Expand All @@ -42,6 +46,7 @@ type testVars struct {
Cognitive azopenai.AzureCognitiveSearchChatExtensionConfiguration
Whisper endpointWithModel
DallE endpointWithModel
Vision endpointWithModel

ChatCompletionsRAI endpointWithModel // at the moment this is Azure only

Expand Down Expand Up @@ -165,6 +170,7 @@ func newTestVars(prefix string) testVars {

DallE: getEndpointWithModel("DALLE", azure),
Whisper: getEndpointWithModel("WHISPER", azure),
Vision: getEndpointWithModel("VISION", azure),
}

if azure {
Expand Down Expand Up @@ -215,6 +221,11 @@ func initEnvVars() {
Model: "dall-e-3",
}

azureOpenAI.Vision = endpointWithModel{
Endpoint: azureOpenAI.Endpoint,
Model: "gpt-4-vision-preview",
}

openAI.Endpoint = endpoint{
APIKey: fakeAPIKey,
URL: fakeEndpoint,
Expand All @@ -233,6 +244,8 @@ func initEnvVars() {
Model: "dall-e-3",
}

openAI.Vision = azureOpenAI.Vision

azureOpenAI.Completions = "text-davinci-003"
openAI.Completions = "text-davinci-003"

Expand Down Expand Up @@ -263,6 +276,13 @@ func initEnvVars() {
}
}

// MultipartRecordingPolicy is an empty type that carries no state.
//
// NOTE(review): nothing in the visible code references this type; it looks like a
// leftover placeholder - confirm before relying on it.
type MultipartRecordingPolicy struct{}

// Do ignores req and always yields a nil response together with a nil error.
func (MultipartRecordingPolicy) Do(req *http.Request) (resp *http.Response, err error) {
	// Both named results are still at their zero values here, i.e. (nil, nil).
	return resp, err
}

// newRecordingTransporter sets up our recording policy to sanitize endpoints and any parts of the response that might
// involve UUIDs that would make the response/request inconsistent.
func newRecordingTransporter(t *testing.T) policy.Transporter {
Expand All @@ -284,6 +304,7 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
azureOpenAI.Whisper.Endpoint.URL,
azureOpenAI.DallE.Endpoint.URL,
azureOpenAI.Vision.Endpoint.URL,
azureOpenAI.ChatCompletionsOYD.Endpoint.URL,
azureOpenAI.ChatCompletionsRAI.Endpoint.URL,
}
Expand Down Expand Up @@ -357,6 +378,7 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {

co.Transport = &http.Client{Transport: tp}
} else {
co.PerRetryPolicies = append(co.PerRetryPolicies, &mimeTypeRecordingPolicy{})
co.Transport = newRecordingTransporter(t)
}

Expand Down Expand Up @@ -417,3 +439,50 @@ func skipNowIfThrottled(t *testing.T, err error) {
t.Skipf("OpenAI resource overloaded, skipping this test")
}
}

// mimeTypeRecordingPolicy is a stateless pipeline policy whose Do method swaps the
// boundary of multipart/form-data requests for a fixed value, so test recordings
// are easier to write.
type mimeTypeRecordingPolicy struct{}

// Do changes out the boundary for a multipart message. This makes it simpler to write
// recordings, since both the Content-Type header and the body become stable across runs.
//
// Requests are forwarded untouched when the record mode is live, or when the request's
// Content-Type is missing, unparseable, or not multipart/form-data.
//
// It returns an error if a multipart/form-data Content-Type carries no boundary parameter,
// or if the request body can't be read or replaced.
func (mrp *mimeTypeRecordingPolicy) Do(req *policy.Request) (*http.Response, error) {
	// the fixed boundary that replaces the randomly generated one.
	const recordingBoundary = "boundary-for-recordings"

	if recording.GetRecordMode() == recording.LiveMode {
		// the rewrite below is strictly to make the IDs in the multipart body stable for
		// test recordings, so there's nothing to do in live mode.
		return req.Next()
	}

	// we'll fix up the multipart to make it more predictable for test recordings.
	// Content-Type: multipart/form-data; boundary=787c880ce3dd11f9b6384d625c399c8490fc8989ceb6b7d208ec7426c12e
	//
	// Header.Get returns "" when the header is absent (ex: a request without a body), which
	// ParseMediaType rejects. Indexing the header's value slice directly would panic instead.
	mediaType, params, err := mime.ParseMediaType(req.Raw().Header.Get("Content-Type"))

	if err != nil || mediaType != "multipart/form-data" {
		// we'll just assume our policy doesn't apply here.
		return req.Next()
	}

	origBoundary := params["boundary"]

	if origBoundary == "" {
		return nil, errors.New("Invalid use of this policy - no boundary was passed as part of the multipart mime type")
	}

	params["boundary"] = recordingBoundary

	// now let's update the body itself - we'll just do a simple string replacement. The entire purpose of the boundary string is to provide a
	// separator, which is distinct from the content.
	body := req.Body()
	defer body.Close()

	origBody, err := io.ReadAll(body)

	if err != nil {
		return nil, err
	}

	newBody := bytes.ReplaceAll(origBody, []byte(origBoundary), []byte(recordingBoundary))

	// SetBody also installs the rewritten Content-Type so the header and body agree on the boundary.
	if err := req.SetBody(streaming.NopCloser(bytes.NewReader(newBody)), mime.FormatMediaType(mediaType, params)); err != nil {
		return nil, err
	}

	return req.Next()
}

0 comments on commit ed5bc4b

Please sign in to comment.