From 40b2199c30e694848afc0d11b09fd0d70f143366 Mon Sep 17 00:00:00 2001
From: Lasse Hyldahl Jensen
Date: Thu, 19 Sep 2024 09:55:07 +0200
Subject: [PATCH] Support binary image URLs without copying image data (#1)

* Support binary image URL

* Use ChatCompletionRequest in test

* Cleanup

* Cleanup

* Comment out encodeImage function

* Split test into 2

* Don't log request
---
 chat.go      | 33 +++++++++++++++++-
 chat_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 131 insertions(+), 1 deletion(-)

diff --git a/chat.go b/chat.go
index dc60f35b9..61f9af335 100644
--- a/chat.go
+++ b/chat.go
@@ -1,7 +1,9 @@
 package openai
 
 import (
+	"bytes"
 	"context"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"net/http"
@@ -62,10 +64,39 @@ const (
 )
 
 type ChatMessageImageURL struct {
-	URL    string         `json:"url,omitempty"`
+	// URL can be a string or a BinaryImageURL object.
+	URL    any            `json:"url,omitempty"`
 	Detail ImageURLDetail `json:"detail,omitempty"`
 }
 
+type BinaryImageURL struct {
+	MimeType string
+	Data     []byte
+}
+
+func (b BinaryImageURL) MarshalJSON() ([]byte, error) {
+	encodedLength := base64.StdEncoding.EncodedLen(len(b.Data))
+	buf := bytes.NewBuffer(make([]byte, 0, 15+len(b.MimeType)+encodedLength))
+
+	buf.WriteString(`"data:`)
+	buf.WriteString(b.MimeType)
+	buf.WriteString(`;base64,`)
+
+	// base64 encode data and write it to the buffer
+	encoder := base64.NewEncoder(base64.StdEncoding, buf)
+	_, err := encoder.Write(b.Data)
+	if err != nil {
+		return nil, err
+	}
+	err = encoder.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	buf.WriteString(`"`)
+	return buf.Bytes(), nil
+}
+
 type ChatMessagePartType string
 
 const (
diff --git a/chat_test.go b/chat_test.go
index 37dc09d4d..e24b24b0f 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -1,7 +1,10 @@
 package openai_test
 
 import (
+	"bytes"
 	"context"
+	"crypto/rand"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -553,3 +556,99 @@ func TestFinishReason(t *testing.T) {
 		}
 	}
 }
+
+func TestChatCompletionRequest_MarshalJSON_LargeImage_Base64(t *testing.T) {
+	// generate 20 mb of random data
+	imageData := generateEncodedData(t, 20*1024*1024)
+	imageURL := encodeImage(t, "image/png", imageData)
+	req := generateRequestWithImage(imageURL)
+
+	_, err := json.Marshal(req)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+}
+
+func TestChatCompletionRequest_MarshalJSON_LargeImage_Binary(t *testing.T) {
+	// generate 20 mb of random data
+	imageData := generateEncodedData(t, 20*1024*1024)
+	imageURL := openai.BinaryImageURL{
+		MimeType: "image/png",
+		Data:     imageData,
+	}
+	req := generateRequestWithImage(imageURL)
+
+	_, err := json.Marshal(req)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+}
+
+func generateEncodedData(t *testing.T, size int64) []byte {
+	data := make([]byte, size)
+	_, err := rand.Read(data)
+	if err != nil {
+		t.Fatalf("Failed to generate random data: %v", err)
+	}
+	return data
+}
+
+func generateRequestWithImage(imageURL any) openai.ChatCompletionRequest {
+	msg := openai.ChatCompletionMessage{
+		Role: openai.ChatMessageRoleUser,
+		MultiContent: []openai.ChatMessagePart{
+			{
+				Type: openai.ChatMessagePartTypeText,
+				Text: "Hello, world!",
+			},
+			{
+				Type: openai.ChatMessagePartTypeImageURL,
+				ImageURL: &openai.ChatMessageImageURL{
+					URL:    imageURL,
+					Detail: openai.ImageURLDetailHigh,
+				},
+			},
+		},
+		Name: "test-name",
+		FunctionCall: &openai.FunctionCall{
+			Name:      "test-function",
+			Arguments: `{"arg1":"value1"}`,
+		},
+		ToolCalls: []openai.ToolCall{
+			{
+				ID:   "tool1",
+				Type: openai.ToolTypeFunction,
+				Function: openai.FunctionCall{
+					Name:      "tool-function",
+					Arguments: `{"arg2":"value2"}`,
+				},
+			},
+		},
+		ToolCallID: "tool-call-id",
+	}
+	req := openai.ChatCompletionRequest{
+		Messages: []openai.ChatCompletionMessage{msg},
+	}
+	return req
+}
+
+func encodeImage(t *testing.T, mimeType string, data []byte) []byte {
+	encodedLength := base64.StdEncoding.EncodedLen(len(data))
+	buf := bytes.NewBuffer(make([]byte, 0, 13+len(mimeType)+encodedLength))
+
+	buf.WriteString(`data:`)
+	buf.WriteString(mimeType)
+	buf.WriteString(`;base64,`)
+
+	// base64 encode data and write it to the buffer
+	encoder := base64.NewEncoder(base64.StdEncoding, buf)
+	_, err := encoder.Write(data)
+	if err != nil {
+		t.Fatalf("Failed to encode image: %v", err)
+	}
+	err = encoder.Close()
+	if err != nil {
+		t.Fatalf("Failed to encode image: %v", err)
+	}
+	return buf.Bytes()
+}