Skip to content

Commit

Permalink
Support binary image URLs without copying image data (#1)
Browse files Browse the repository at this point in the history
* Support binary image URL

* Use ChatCompletionRequest in test

* Cleanup

* Cleanup

* Comment out encodeImage function

* Split test into 2

* Don't log request
  • Loading branch information
lasse-it authored Sep 19, 2024
1 parent a5fb553 commit 40b2199
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
33 changes: 32 additions & 1 deletion chat.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package openai

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"net/http"
Expand Down Expand Up @@ -62,10 +64,39 @@ const (
)

type ChatMessageImageURL struct {
URL string `json:"url,omitempty"`
// URL can be a string or a BinaryImageURL object.
URL any `json:"url,omitempty"`
Detail ImageURLDetail `json:"detail,omitempty"`
}

type BinaryImageURL struct {
MimeType string
Data []byte
}

func (b BinaryImageURL) MarshalJSON() ([]byte, error) {
encodedLength := base64.StdEncoding.EncodedLen(len(b.Data))
buf := bytes.NewBuffer(make([]byte, 0, 15+len(b.MimeType)+encodedLength))

buf.WriteString(`"data:`)
buf.WriteString(b.MimeType)
buf.WriteString(`;base64,`)

// base64 encode data and write it to the buffer
encoder := base64.NewEncoder(base64.StdEncoding, buf)
_, err := encoder.Write(b.Data)
if err != nil {
return nil, err
}
err = encoder.Close()
if err != nil {
return nil, err
}

buf.WriteString(`"`)
return buf.Bytes(), nil
}

type ChatMessagePartType string

const (
Expand Down
99 changes: 99 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package openai_test

import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -553,3 +556,99 @@ func TestFinishReason(t *testing.T) {
}
}
}

func TestChatCompletionRequest_MarshalJSON_LargeImage_Base64(t *testing.T) {
// generate 20 mb of random data
imageData := generateEncodedData(t, 20*1024*1024)
imageURL := encodeImage(t, "image/png", imageData)
req := generateRequestWithImage(imageURL)

_, err := json.Marshal(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
}

func TestChatCompletionRequest_MarshalJSON_LargeImage_Binary(t *testing.T) {
// generate 20 mb of random data
imageData := generateEncodedData(t, 20*1024*1024)
imageURL := openai.BinaryImageURL{
MimeType: "image/png",
Data: imageData,
}
req := generateRequestWithImage(imageURL)

_, err := json.Marshal(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
}

func generateEncodedData(t *testing.T, size int64) []byte {
data := make([]byte, size)
_, err := rand.Read(data)
if err != nil {
t.Fatalf("Failed to generate random data: %v", err)
}
return data
}

func generateRequestWithImage(imageURL any) openai.ChatCompletionRequest {
msg := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
MultiContent: []openai.ChatMessagePart{
{
Type: openai.ChatMessagePartTypeText,
Text: "Hello, world!",
},
{
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: imageURL,
Detail: openai.ImageURLDetailHigh,
},
},
},
Name: "test-name",
FunctionCall: &openai.FunctionCall{
Name: "test-function",
Arguments: `{"arg1":"value1"}`,
},
ToolCalls: []openai.ToolCall{
{
ID: "tool1",
Type: openai.ToolTypeFunction,
Function: openai.FunctionCall{
Name: "tool-function",
Arguments: `{"arg2":"value2"}`,
},
},
},
ToolCallID: "tool-call-id",
}
req := openai.ChatCompletionRequest{
Messages: []openai.ChatCompletionMessage{msg},
}
return req
}

func encodeImage(t *testing.T, mimeType string, data []byte) []byte {
encodedLength := base64.StdEncoding.EncodedLen(len(data))
buf := bytes.NewBuffer(make([]byte, 0, 13+len(mimeType)+encodedLength))

buf.WriteString(`data:`)
buf.WriteString(mimeType)
buf.WriteString(`;base64,`)

// base64 encode data and write it to the buffer
encoder := base64.NewEncoder(base64.StdEncoding, buf)
_, err := encoder.Write(data)
if err != nil {
t.Fatalf("Failed to encode image: %v", err)
}
err = encoder.Close()
if err != nil {
t.Fatalf("Failed to encode image: %v", err)
}
return buf.Bytes()
}

0 comments on commit 40b2199

Please sign in to comment.