From cf2f6076a33da937718377415e88deb5ff35abf5 Mon Sep 17 00:00:00 2001 From: Ronald Mannak Date: Sat, 13 Apr 2024 08:21:26 -0700 Subject: [PATCH] Add support for GPT-4 Vision (#17) * Add Content * Make ChatContent public * Rename ChatContent -> MessageContent * Update ChatMessage * Fix init * Add init for converting image data to base64 * Return emtpy string in case of no text --- Sources/CleverBird/chat/ChatMessage.swift | 75 +++++++++- .../chat/ChatThread+tokenCount.swift | 27 +++- Sources/CleverBird/chat/ChatThread.swift | 10 ++ Sources/CleverBird/chat/MessageContent.swift | 90 ++++++++++++ .../CleverBirdTests/MessageContentTests.swift | 24 ++++ .../MessageEncodingTests.swift | 134 ++++++++++++++++++ .../OpenAIChatThreadTests.swift | 2 +- 7 files changed, 355 insertions(+), 7 deletions(-) create mode 100644 Sources/CleverBird/chat/MessageContent.swift create mode 100644 Tests/CleverBirdTests/MessageContentTests.swift create mode 100644 Tests/CleverBirdTests/MessageEncodingTests.swift diff --git a/Sources/CleverBird/chat/ChatMessage.swift b/Sources/CleverBird/chat/ChatMessage.swift index 322edbe..7072721 100644 --- a/Sources/CleverBird/chat/ChatMessage.swift +++ b/Sources/CleverBird/chat/ChatMessage.swift @@ -24,7 +24,7 @@ public struct ChatMessage: Codable, Identifiable { public let role: Role /// The contents of the message. `content` is required for all messages except assistant messages with function calls. - public let content: String? + public let content: Content? /// The name and arguments of a function that should be called, as generated by the model. public let functionCall: FunctionCall? @@ -36,14 +36,21 @@ public struct ChatMessage: Codable, Identifiable { content: String? = nil, id: String? = nil, functionCall: FunctionCall? = nil) throws { + try self.init(role: role, media: content != nil ? .text(content!) : nil, id: id, functionCall: functionCall) + } + public init(role: Role, + media: ChatMessage.Content?, + id: String? = nil, + functionCall: FunctionCall? = nil) throws { + // Validation: Content is required for all messages except assistant messages with function calls. - if content == nil && !(role == .assistant && functionCall != nil) { + if media == nil && !(role == .assistant && functionCall != nil) { throw CleverBirdError.invalidMessageContent } self.role = role - self.content = content + self.content = media self.name = functionCall?.name if role == .function { // If the role is "function" I need to set functionCall to nil, otherwise this will @@ -58,7 +65,9 @@ public struct ChatMessage: Codable, Identifiable { } else { var hasher = Hasher() hasher.combine(self.role) - hasher.combine(self.content ?? "") + if let content { + hasher.combine(content) + } let hashValue = abs(hasher.finalize()) let timestamp = Int(Date.now.timeIntervalSince1970*10000) @@ -69,7 +78,7 @@ public struct ChatMessage: Codable, Identifiable { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) self.role = try container.decode(Role.self, forKey: .role) - self.content = try container.decodeIfPresent(String.self, forKey: .content) + self.content = try container.decodeIfPresent(Content.self, forKey: .content) self.functionCall = try container.decodeIfPresent(FunctionCall.self, forKey: .functionCall) self.name = try container.decodeIfPresent(String.self, forKey: .name) self.id = "pending" @@ -92,3 +101,59 @@ extension ChatMessage: Equatable { && lhs.content == rhs.content } } + +extension ChatMessage { + + public enum Content: Codable, Equatable, CustomStringConvertible, Hashable { + + case text(String) + case media([MessageContent]) + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let textContent = try? container.decode(String.self) { + self = .text(textContent) + } else if let chatContents = try? container.decode([MessageContent].self) { + self = .media(chatContents) + } else { + throw DecodingError.typeMismatch(MessageContent.self, DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Unsupported type for Content")) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .text(let text): + try container.encode(text) + case .media(let contents): + try container.encode(contents) + } + } + + public static func == (lhs: Content, rhs: Content) -> Bool { + switch (lhs, rhs) { + case (.text(let leftText), .text(let rightText)): + return leftText == rightText + case (.media(let leftMedia), .media(let rightMedia)): + return leftMedia == rightMedia + default: + return false + } + } + + public var description: String { + switch self { + case .media(let messageContents): + for messageContent in messageContents { + if case .text(let textValue) = messageContent { + return textValue + } + } + return "" + case .text(let text): + return text + } + } + } +} + diff --git a/Sources/CleverBird/chat/ChatThread+tokenCount.swift b/Sources/CleverBird/chat/ChatThread+tokenCount.swift index d1187d9..e4caf88 100644 --- a/Sources/CleverBird/chat/ChatThread+tokenCount.swift +++ b/Sources/CleverBird/chat/ChatThread+tokenCount.swift @@ -29,7 +29,32 @@ extension ChatThread { let roleTokens = try tokenEncoder.encode(text: message.role.rawValue).count let contentTokens: Int if let content = message.content { - contentTokens = try tokenEncoder.encode(text: content).count + switch content { + case .text(let text): + contentTokens = try tokenEncoder.encode(text: text).count + case .media(let media): + var count = 0 + for medium in media { + switch medium { + case .text(let text): + count += try tokenEncoder.encode(text: text).count + case .imageUrl(let url): + // See https://platform.openai.com/docs/guides/vision/calculating-costs + switch url.detail { + // TODO: calculate real values for auto and high + case .auto: + count += 1105 + case .high: + count += 1105 + case .low: + count += 85 + case .none: + count += 1105 + } + } + } + contentTokens = count + } } else if let functionCall = message.functionCall { let jsonEncoder = JSONEncoder() let jsonData = try jsonEncoder.encode(functionCall) diff --git a/Sources/CleverBird/chat/ChatThread.swift b/Sources/CleverBird/chat/ChatThread.swift index 073e2d6..8266691 100644 --- a/Sources/CleverBird/chat/ChatThread.swift +++ b/Sources/CleverBird/chat/ChatThread.swift @@ -32,6 +32,16 @@ public class ChatThread: Codable { } return self } + + @discardableResult + public func addUserMessage(_ media: [MessageContent]) -> Self { + do { + try addMessage(ChatMessage(role: .user, media: .media(media))) + } catch { + print(error.localizedDescription) + } + return self + } @discardableResult public func addAssistantMessage(_ content: String) -> Self { diff --git a/Sources/CleverBird/chat/MessageContent.swift b/Sources/CleverBird/chat/MessageContent.swift new file mode 100644 index 0000000..ad6af19 --- /dev/null +++ b/Sources/CleverBird/chat/MessageContent.swift @@ -0,0 +1,90 @@ +// +// ChatContent.swift +// +// +// Created by Ronald Mannak on 4/12/24. +// + +import Foundation + +public enum MessageContent: Hashable { + case text(String) + case imageUrl(URLDetail) +} + +extension MessageContent { + public enum ContentType: String, Codable, Hashable { + case text + case imageUrl = "image_url" + } + + public struct URLDetail: Codable, Equatable, Hashable { + + public enum Detail: String, Codable { + case low, high, auto + } + + let url: String + let detail: Detail? + + public init(url: String, detail: Detail? = nil) { + self.url = url + self.detail = detail + } + + public init(url: URL, detail: Detail? = nil) { + self.init(url: url.absoluteString, detail: detail) + } + + public init(imageData: Data, detail: Detail? = nil) { + let base64 = imageData.base64EncodedString() + self.init(url: "data:image/jpeg;base64,\(base64)", detail: detail) + } + } +} + +extension MessageContent: Codable { + + private enum CodingKeys: String, CodingKey { + case type, text, imageUrl + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(ContentType.self, forKey: .type) + + switch type { + case .text: + let text = try container.decode(String.self, forKey: .text) + self = .text(text) + case .imageUrl: + let imageUrl = try container.decode(URLDetail.self, forKey: .imageUrl) + self = .imageUrl(imageUrl) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case .text(let text): + try container.encode(ContentType.text.rawValue, forKey: .type) + try container.encode(text, forKey: .text) + case .imageUrl(let urlDetail): + try container.encode(ContentType.imageUrl.rawValue, forKey: .type) + try container.encode(urlDetail, forKey: .imageUrl) + } + } +} + +extension MessageContent: Equatable { + public static func == (lhs: MessageContent, rhs: MessageContent) -> Bool { + switch (lhs, rhs) { + case (.text(let lhsText), .text(let rhsText)): + return lhsText == rhsText + case (.imageUrl(let lhsUrlDetail), .imageUrl(let rhsUrlDetail)): + return lhsUrlDetail == rhsUrlDetail + default: + return false + } + } +} diff --git a/Tests/CleverBirdTests/MessageContentTests.swift b/Tests/CleverBirdTests/MessageContentTests.swift new file mode 100644 index 0000000..f43c80b --- /dev/null +++ b/Tests/CleverBirdTests/MessageContentTests.swift @@ -0,0 +1,24 @@ +// +// MessageContentTests.swift +// +// +// Created by Ronald Mannak on 4/12/24. +// + +import Foundation +import XCTest +@testable import CleverBird + +class MessageContentTests: XCTestCase { + + func testURL() { + let content = MessageContent.URLDetail(url: URL(string: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")!) + XCTAssertEqual(content.url, "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg") + } + + func testBase64() { + let data = "Hello, world".data(using: .utf8)! + let content = MessageContent.URLDetail(imageData: data) + XCTAssertEqual(content.url, "") + } +} diff --git a/Tests/CleverBirdTests/MessageEncodingTests.swift b/Tests/CleverBirdTests/MessageEncodingTests.swift new file mode 100644 index 0000000..6fd4a2a --- /dev/null +++ b/Tests/CleverBirdTests/MessageEncodingTests.swift @@ -0,0 +1,134 @@ +// +// ContentEncodingTests.swift +// +// +// Created by Ronald Mannak on 4/12/24. +// + +import XCTest +@testable import CleverBird + +class ContentEncodingTests: XCTestCase { + + let text = """ + { + "type": "text", + "text": "What’s in this image?" + } + """.data(using: .utf8)! + + let imageURL = """ + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + } + } + """.data(using: .utf8)! + + let imageURLDetail = """ + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "detail": "high" + } + } + """.data(using: .utf8)! + + let imageData = """ + { + "type": "image_url", + "image_url": { + "url": "" + } + } + """.data(using: .utf8)! + + var encoder: JSONEncoder! + var decoder: JSONDecoder! + + override func setUp() { + encoder = JSONEncoder() + decoder = JSONDecoder() + encoder.keyEncodingStrategy = .convertToSnakeCase + decoder.keyDecodingStrategy = .convertFromSnakeCase + } + + func testTextDecoding() throws { + let object = try decoder.decode(MessageContent.self, from: text) + switch object { + case .imageUrl(_): + XCTFail() + case .text(let text): + XCTAssertEqual(text, "What’s in this image?") + } + } + + func testImageURLDecoding() throws { + _ = try decoder.decode(MessageContent.self, from: imageURL) + } + + func testImageURLDetailDecoding() throws { + _ = try decoder.decode(MessageContent.self, from: imageURLDetail) + } + + func testImageDataDecoding() throws { + _ = try decoder.decode(MessageContent.self, from: imageData) + } + + func testTextEncoding() throws { + let content = MessageContent.text("What’s in this image?") + let json = try encoder.encode(content) + + let object = try decoder.decode(MessageContent.self, from: json) + switch object { + case .imageUrl(_): + XCTFail() + case .text(let text): + XCTAssertEqual(text, "What’s in this image?") + } + } + + func testImageURLEncoding() throws { + let content = MessageContent.imageUrl(MessageContent.URLDetail(url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")) + let json = try encoder.encode(content) + let object = try decoder.decode(MessageContent.self, from: json) + + switch object { + case .imageUrl(let detail): + XCTAssertEqual(detail.url, "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg") + XCTAssertEqual(detail.detail, nil) + case .text(_): + XCTFail() + } + } + + func testImageURLDetailEncoding() throws { + let content = MessageContent.imageUrl(MessageContent.URLDetail(url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", detail: .high)) + let json = try encoder.encode(content) + let object = try decoder.decode(MessageContent.self, from: json) + + switch object { + case .imageUrl(let detail): + XCTAssertEqual(detail.url, "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg") + XCTAssertEqual(detail.detail, .high) + case .text(_): + XCTFail() + } + } + + func testImageDataEncoding() throws { + let content = MessageContent.imageUrl(MessageContent.URLDetail(url: "")) + let json = try encoder.encode(content) + let object = try decoder.decode(MessageContent.self, from: json) + + switch object { + case .imageUrl(let detail): + XCTAssertEqual(detail.url, "") + XCTAssertEqual(detail.detail, nil) + case .text(_): + XCTFail() + } + } +} diff --git a/Tests/CleverBirdTests/OpenAIChatThreadTests.swift b/Tests/CleverBirdTests/OpenAIChatThreadTests.swift index 7379e72..494a389 100644 --- a/Tests/CleverBirdTests/OpenAIChatThreadTests.swift +++ b/Tests/CleverBirdTests/OpenAIChatThreadTests.swift @@ -13,7 +13,7 @@ class OpenAIChatThreadTests: XCTestCase { XCTAssertEqual(2, chatThread.getMessages().count) XCTAssertEqual(1, chatThread.getNonSystemMessages().count) - XCTAssertEqual(userMessageContent, chatThread.getNonSystemMessages().first?.content) + XCTAssertEqual(userMessageContent, chatThread.getNonSystemMessages().first?.content?.description) } func testTokenCount() throws {